diff --git a/src/api.rs b/src/api.rs index 26aef7d..fca997a 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,9 +1,7 @@ -use std::cmp::max; use std::collections::HashMap; use std::ffi::CString; use std::ffi::c_char; use std::ffi::c_void; -use std::fmt::Display; use std::path::Path; use std::path::PathBuf; use std::ptr::null; @@ -215,8 +213,10 @@ pub struct DbCacheParams { warmup: i32, /// Steps Computation Mask controls which steps can be cached - #[builder(default = "ScmPreset::default()")] - scm_mask: ScmPreset, + /// E.g.: "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1" + /// where 1 means compute, 0 means cache + #[builder(default = "CLibString::default()")] + scm_mask: CLibString, /// Scm Policy #[builder(default = "ScmPolicy::default()")] @@ -233,133 +233,6 @@ pub enum ScmPolicy { Dynamic, } -/// Steps Computation Mask Preset controls which steps can be cached -#[derive(Debug, Default, Clone)] -pub enum ScmPreset { - Slow, - #[default] - Medium, - Fast, - Ultra, - /// E.g.: "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1" - /// where 1 means compute, 0 means cache - Custom(String), -} - -impl ScmPreset { - fn to_vec_string(&self, steps: i32) -> String { - match self { - ScmPreset::Slow => ScmPresetBins { - compute_bins: vec![8, 3, 3, 2, 1, 1], - cache_bins: vec![1, 2, 2, 2, 3], - steps, - } - .to_string(), - ScmPreset::Medium => ScmPresetBins { - compute_bins: vec![6, 2, 2, 2, 2, 1], - cache_bins: vec![1, 3, 3, 3, 3], - steps, - } - .to_string(), - ScmPreset::Fast => ScmPresetBins { - compute_bins: vec![6, 1, 1, 1, 1, 1], - cache_bins: vec![1, 3, 4, 5, 4], - steps, - } - .to_string(), - ScmPreset::Ultra => ScmPresetBins { - compute_bins: vec![4, 1, 1, 1, 1], - cache_bins: vec![2, 5, 6, 7], - steps, - } - .to_string(), - ScmPreset::Custom(s) => s.clone(), - } - } -} - -#[derive(Debug, Clone)] -struct ScmPresetBins { - compute_bins: Vec, - cache_bins: Vec, - steps: i32, -} - -impl ScmPresetBins { - fn maybe_scale(&self) -> ScmPresetBins { - if self.steps != 28 && self.steps > 0 { - return self.scale(); - } - self.clone() - } - - fn scale(&self) -> ScmPresetBins { - let scale = self.steps as f32 / 28.0; - let scaled_compute_bins = self - .compute_bins - .iter() - .map(|b| max(1, (*b as f32 * scale * 0.5) as i32)) - .collect(); - let scaled_cached_bins = self - .cache_bins - .iter() - .map(|b| max(1, (*b as f32 * scale * 0.5) as i32)) - .collect(); - ScmPresetBins { - compute_bins: scaled_compute_bins, - cache_bins: scaled_cached_bins, - steps: self.steps, - } - } - - fn generate_vec_mask(&self) -> Vec { - let mut mask = Vec::new(); - let mut c_idx = 0; - let mut cache_idx = 0; - - while mask.len() < self.steps as usize { - if c_idx < self.compute_bins.len() { - let compute_count = self.compute_bins[c_idx]; - for _ in 0..compute_count { - if mask.len() < self.steps as usize { - mask.push(1); - } - } - c_idx += 1; - } - if cache_idx < self.cache_bins.len() { - let cache_count = self.cache_bins[cache_idx]; - for _ in 0..cache_count { - if mask.len() < self.steps as usize { - mask.push(0); - } - } - cache_idx += 1; - } - if c_idx >= self.compute_bins.len() && cache_idx >= self.cache_bins.len() { - break; - } - } - if let Some(last) = mask.last_mut() { - *last = 1; - } - mask - } -} - -impl Display for ScmPresetBins { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mask: String = self - .maybe_scale() - .generate_vec_mask() - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(","); - write!(f, "{mask}") - } -} - /// Config struct for a specific diffusion model #[derive(Builder, Debug, Clone)] #[builder( @@ -980,10 +853,7 @@ pub struct Config { disable_auto_resize_ref_image: bool, #[builder(default = "Self::cache_init()", private)] - cache: sd_cache_params_t, - - #[builder(default = "None", private)] - scm_mask: Option, + cache: (sd_cache_params_t, Option), } impl ConfigBuilder { @@ -1003,44 +873,47 @@ impl ConfigBuilder { } } - fn cache_init() -> sd_cache_params_t { - sd_cache_params_t { - mode: sd_cache_mode_t::SD_CACHE_DISABLED, - reuse_threshold: 1.0, - start_percent: 0.15, - end_percent: 0.95, - error_decay_rate: 1.0, - use_relative_threshold: true, - reset_error_on_compute: true, - Fn_compute_blocks: 8, - Bn_compute_blocks: 0, - residual_diff_threshold: 0.08, - max_warmup_steps: 8, - max_cached_steps: -1, - max_continuous_cached_steps: -1, - taylorseer_n_derivatives: 1, - taylorseer_skip_interval: 1, - scm_mask: null(), - scm_policy_dynamic: true, - spectrum_w: 0.4, - spectrum_m: 3, - spectrum_lam: 1.0, - spectrum_window_size: 2, - spectrum_flex_window: 0.5, - spectrum_warmup_steps: 4, - spectrum_stop_percent: 0.9, - } + fn cache_init() -> (sd_cache_params_t, Option) { + ( + sd_cache_params_t { + mode: sd_cache_mode_t::SD_CACHE_DISABLED, + reuse_threshold: 1.0, + start_percent: 0.15, + end_percent: 0.95, + error_decay_rate: 1.0, + use_relative_threshold: true, + reset_error_on_compute: true, + Fn_compute_blocks: 8, + Bn_compute_blocks: 0, + residual_diff_threshold: 0.08, + max_warmup_steps: 8, + max_cached_steps: -1, + max_continuous_cached_steps: -1, + taylorseer_n_derivatives: 1, + taylorseer_skip_interval: 1, + scm_mask: null(), + scm_policy_dynamic: true, + spectrum_w: 0.4, + spectrum_m: 3, + spectrum_lam: 1.0, + spectrum_window_size: 2, + spectrum_flex_window: 0.5, + spectrum_warmup_steps: 4, + spectrum_stop_percent: 0.9, + }, + None, + ) } pub fn no_caching(&mut self) -> &mut Self { let mut cache = Self::cache_init(); - cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED; + cache.0.mode = sd_cache_mode_t::SD_CACHE_DISABLED; self.cache = Some(cache); self } pub fn spectrum_caching(&mut self, params: SpectrumCacheParams) -> &mut Self { - let mut cache = Self::cache_init(); + let (mut cache, mask) = Self::cache_init(); cache.mode = sd_cache_mode_t::SD_CACHE_SPECTRUM; cache.spectrum_w = params.w; cache.spectrum_m = params.m; @@ -1049,12 +922,12 @@ impl ConfigBuilder { cache.spectrum_flex_window = params.flex; cache.spectrum_warmup_steps = params.warmup; cache.spectrum_stop_percent = params.stop; - self.cache = Some(cache); + self.cache = Some((cache, mask)); self } pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self { - let mut cache = Self::cache_init(); + let (mut cache, mask) = Self::cache_init(); cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE; cache.reuse_threshold = params.threshold; cache.start_percent = params.start; @@ -1062,22 +935,22 @@ impl ConfigBuilder { cache.error_decay_rate = params.decay; cache.use_relative_threshold = params.relative; cache.reset_error_on_compute = params.reset; - self.cache = Some(cache); + self.cache = Some((cache, mask)); self } pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self { - let mut cache = Self::cache_init(); + let (mut cache, mask) = Self::cache_init(); cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE; cache.reuse_threshold = params.threshold; cache.start_percent = params.start; cache.end_percent = params.end; - self.cache = Some(cache); + self.cache = Some((cache, mask)); self } pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self { - let mut cache = Self::cache_init(); + let (mut cache, _) = Self::cache_init(); cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE; cache.Fn_compute_blocks = params.fn_blocks; cache.Bn_compute_blocks = params.bn_blocks; @@ -1087,25 +960,20 @@ impl ConfigBuilder { ScmPolicy::Static => false, ScmPolicy::Dynamic => true, }; - self.scm_mask = Some(Some(CLibString::from( - params - .scm_mask - .to_vec_string(self.steps.unwrap_or_default()), - ))); - - self.cache = Some(cache); + self.cache = Some((cache, Some(params.scm_mask))); self } pub fn taylor_seer_caching(&mut self) -> &mut Self { - let mut cache = Self::cache_init(); + let (mut cache, mask) = Self::cache_init(); cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER; - self.cache = Some(cache); + self.cache = Some((cache, mask)); self } pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self { - self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT; + self.db_cache_caching(params).cache.as_mut().unwrap().0.mode = + sd_cache_mode_t::SD_CACHE_CACHE_DIT; self } } @@ -1143,10 +1011,7 @@ impl From for ConfigBuilder { .preview_mode(value.preview_mode) .preview_noisy(value.preview_noisy) .preview_interval(value.preview_interval) - .cache(value.cache); - if let Some(scm_mask) = value.scm_mask { - builder.scm_mask(scm_mask.clone()); - } + .cache(value.cache.clone()); builder } } @@ -1458,8 +1323,8 @@ fn gen_img_maybe_progress( }) .collect(); - let mut cache = config.cache; - if let Some(scm_mask) = &config.scm_mask { + let mut cache = config.cache.0; + if let Some(scm_mask) = &config.cache.1 { cache.scm_mask = scm_mask.as_ptr(); } diff --git a/sys/stable-diffusion.cpp b/sys/stable-diffusion.cpp index d6dd6d7..630ee03 160000 --- a/sys/stable-diffusion.cpp +++ b/sys/stable-diffusion.cpp @@ -1 +1 @@ -Subproject commit d6dd6d7b555c233bb9bc9f20b4751eb8c9269743 +Subproject commit 630ee03f23bd9947f610dd9fe038c56c0ff9c2de