Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 51 additions & 186 deletions src/api.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::cmp::max;
use std::collections::HashMap;
use std::ffi::CString;
use std::ffi::c_char;
use std::ffi::c_void;
use std::fmt::Display;
use std::path::Path;
use std::path::PathBuf;
use std::ptr::null;
Expand Down Expand Up @@ -215,8 +213,10 @@ pub struct DbCacheParams {
warmup: i32,

/// Steps Computation Mask controls which steps can be cached
#[builder(default = "ScmPreset::default()")]
scm_mask: ScmPreset,
/// E.g.: "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1"
/// where 1 means compute, 0 means cache
#[builder(default = "CLibString::default()")]
scm_mask: CLibString,

/// Scm Policy
#[builder(default = "ScmPolicy::default()")]
Expand All @@ -233,133 +233,6 @@ pub enum ScmPolicy {
Dynamic,
}

/// Steps Computation Mask Preset controls which steps can be cached
#[derive(Debug, Default, Clone)]
pub enum ScmPreset {
Slow,
#[default]
Medium,
Fast,
Ultra,
/// E.g.: "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1"
/// where 1 means compute, 0 means cache
Custom(String),
}

impl ScmPreset {
fn to_vec_string(&self, steps: i32) -> String {
match self {
ScmPreset::Slow => ScmPresetBins {
compute_bins: vec![8, 3, 3, 2, 1, 1],
cache_bins: vec![1, 2, 2, 2, 3],
steps,
}
.to_string(),
ScmPreset::Medium => ScmPresetBins {
compute_bins: vec![6, 2, 2, 2, 2, 1],
cache_bins: vec![1, 3, 3, 3, 3],
steps,
}
.to_string(),
ScmPreset::Fast => ScmPresetBins {
compute_bins: vec![6, 1, 1, 1, 1, 1],
cache_bins: vec![1, 3, 4, 5, 4],
steps,
}
.to_string(),
ScmPreset::Ultra => ScmPresetBins {
compute_bins: vec![4, 1, 1, 1, 1],
cache_bins: vec![2, 5, 6, 7],
steps,
}
.to_string(),
ScmPreset::Custom(s) => s.clone(),
}
}
}

#[derive(Debug, Clone)]
struct ScmPresetBins {
compute_bins: Vec<i32>,
cache_bins: Vec<i32>,
steps: i32,
}

impl ScmPresetBins {
fn maybe_scale(&self) -> ScmPresetBins {
if self.steps != 28 && self.steps > 0 {
return self.scale();
}
self.clone()
}

fn scale(&self) -> ScmPresetBins {
let scale = self.steps as f32 / 28.0;
let scaled_compute_bins = self
.compute_bins
.iter()
.map(|b| max(1, (*b as f32 * scale * 0.5) as i32))
.collect();
let scaled_cached_bins = self
.cache_bins
.iter()
.map(|b| max(1, (*b as f32 * scale * 0.5) as i32))
.collect();
ScmPresetBins {
compute_bins: scaled_compute_bins,
cache_bins: scaled_cached_bins,
steps: self.steps,
}
}

fn generate_vec_mask(&self) -> Vec<i32> {
let mut mask = Vec::new();
let mut c_idx = 0;
let mut cache_idx = 0;

while mask.len() < self.steps as usize {
if c_idx < self.compute_bins.len() {
let compute_count = self.compute_bins[c_idx];
for _ in 0..compute_count {
if mask.len() < self.steps as usize {
mask.push(1);
}
}
c_idx += 1;
}
if cache_idx < self.cache_bins.len() {
let cache_count = self.cache_bins[cache_idx];
for _ in 0..cache_count {
if mask.len() < self.steps as usize {
mask.push(0);
}
}
cache_idx += 1;
}
if c_idx >= self.compute_bins.len() && cache_idx >= self.cache_bins.len() {
break;
}
}
if let Some(last) = mask.last_mut() {
*last = 1;
}
mask
}
}

impl Display for ScmPresetBins {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mask: String = self
.maybe_scale()
.generate_vec_mask()
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(",");
write!(f, "{mask}")
}
}

/// Config struct for a specific diffusion model
#[derive(Builder, Debug, Clone)]
#[builder(
Expand Down Expand Up @@ -980,10 +853,7 @@ pub struct Config {
disable_auto_resize_ref_image: bool,

#[builder(default = "Self::cache_init()", private)]
cache: sd_cache_params_t,

#[builder(default = "None", private)]
scm_mask: Option<CLibString>,
cache: (sd_cache_params_t, Option<CLibString>),
}

impl ConfigBuilder {
Expand All @@ -1003,44 +873,47 @@ impl ConfigBuilder {
}
}

fn cache_init() -> sd_cache_params_t {
sd_cache_params_t {
mode: sd_cache_mode_t::SD_CACHE_DISABLED,
reuse_threshold: 1.0,
start_percent: 0.15,
end_percent: 0.95,
error_decay_rate: 1.0,
use_relative_threshold: true,
reset_error_on_compute: true,
Fn_compute_blocks: 8,
Bn_compute_blocks: 0,
residual_diff_threshold: 0.08,
max_warmup_steps: 8,
max_cached_steps: -1,
max_continuous_cached_steps: -1,
taylorseer_n_derivatives: 1,
taylorseer_skip_interval: 1,
scm_mask: null(),
scm_policy_dynamic: true,
spectrum_w: 0.4,
spectrum_m: 3,
spectrum_lam: 1.0,
spectrum_window_size: 2,
spectrum_flex_window: 0.5,
spectrum_warmup_steps: 4,
spectrum_stop_percent: 0.9,
}
fn cache_init() -> (sd_cache_params_t, Option<CLibString>) {
(
sd_cache_params_t {
mode: sd_cache_mode_t::SD_CACHE_DISABLED,
reuse_threshold: 1.0,
start_percent: 0.15,
end_percent: 0.95,
error_decay_rate: 1.0,
use_relative_threshold: true,
reset_error_on_compute: true,
Fn_compute_blocks: 8,
Bn_compute_blocks: 0,
residual_diff_threshold: 0.08,
max_warmup_steps: 8,
max_cached_steps: -1,
max_continuous_cached_steps: -1,
taylorseer_n_derivatives: 1,
taylorseer_skip_interval: 1,
scm_mask: null(),
scm_policy_dynamic: true,
spectrum_w: 0.4,
spectrum_m: 3,
spectrum_lam: 1.0,
spectrum_window_size: 2,
spectrum_flex_window: 0.5,
spectrum_warmup_steps: 4,
spectrum_stop_percent: 0.9,
},
None,
)
}

pub fn no_caching(&mut self) -> &mut Self {
let mut cache = Self::cache_init();
cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
cache.0.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
self.cache = Some(cache);
self
}

pub fn spectrum_caching(&mut self, params: SpectrumCacheParams) -> &mut Self {
let mut cache = Self::cache_init();
let (mut cache, mask) = Self::cache_init();
cache.mode = sd_cache_mode_t::SD_CACHE_SPECTRUM;
cache.spectrum_w = params.w;
cache.spectrum_m = params.m;
Expand All @@ -1049,35 +922,35 @@ impl ConfigBuilder {
cache.spectrum_flex_window = params.flex;
cache.spectrum_warmup_steps = params.warmup;
cache.spectrum_stop_percent = params.stop;
self.cache = Some(cache);
self.cache = Some((cache, mask));
self
}

pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
let mut cache = Self::cache_init();
let (mut cache, mask) = Self::cache_init();
cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
cache.reuse_threshold = params.threshold;
cache.start_percent = params.start;
cache.end_percent = params.end;
cache.error_decay_rate = params.decay;
cache.use_relative_threshold = params.relative;
cache.reset_error_on_compute = params.reset;
self.cache = Some(cache);
self.cache = Some((cache, mask));
self
}

pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
let mut cache = Self::cache_init();
let (mut cache, mask) = Self::cache_init();
cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
cache.reuse_threshold = params.threshold;
cache.start_percent = params.start;
cache.end_percent = params.end;
self.cache = Some(cache);
self.cache = Some((cache, mask));
self
}

pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
let mut cache = Self::cache_init();
let (mut cache, _) = Self::cache_init();
cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
cache.Fn_compute_blocks = params.fn_blocks;
cache.Bn_compute_blocks = params.bn_blocks;
Expand All @@ -1087,25 +960,20 @@ impl ConfigBuilder {
ScmPolicy::Static => false,
ScmPolicy::Dynamic => true,
};
self.scm_mask = Some(Some(CLibString::from(
params
.scm_mask
.to_vec_string(self.steps.unwrap_or_default()),
)));

self.cache = Some(cache);
self.cache = Some((cache, Some(params.scm_mask)));
self
}

pub fn taylor_seer_caching(&mut self) -> &mut Self {
let mut cache = Self::cache_init();
let (mut cache, mask) = Self::cache_init();
cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
self.cache = Some(cache);
self.cache = Some((cache, mask));
self
}

pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
self.db_cache_caching(params).cache.unwrap().mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
self.db_cache_caching(params).cache.as_mut().unwrap().0.mode =
sd_cache_mode_t::SD_CACHE_CACHE_DIT;
self
}
}
Expand Down Expand Up @@ -1143,10 +1011,7 @@ impl From<Config> for ConfigBuilder {
.preview_mode(value.preview_mode)
.preview_noisy(value.preview_noisy)
.preview_interval(value.preview_interval)
.cache(value.cache);
if let Some(scm_mask) = value.scm_mask {
builder.scm_mask(scm_mask.clone());
}
.cache(value.cache.clone());
builder
}
}
Expand Down Expand Up @@ -1458,8 +1323,8 @@ fn gen_img_maybe_progress(
})
.collect();

let mut cache = config.cache;
if let Some(scm_mask) = &config.scm_mask {
let mut cache = config.cache.0;
if let Some(scm_mask) = &config.cache.1 {
cache.scm_mask = scm_mask.as_ptr();
}

Expand Down
Loading