diff --git a/.gitignore b/.gitignore index 990c832..ee15341 100644 --- a/.gitignore +++ b/.gitignore @@ -20,8 +20,10 @@ Cargo.lock # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ bin/act +build/ *.png *.jpg *.jpeg .idea/ +.vscode/ \ No newline at end of file diff --git a/cli/src/main.rs b/cli/src/main.rs index feaa834..6fcb34d 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -4,7 +4,9 @@ use chrono::Local; use clap::{Parser, ValueEnum}; use diffusion_rs::{ - api::{PreviewType, gen_img}, + api::{ + DbCacheParamsBuilder, EasyCacheParamsBuilder, PreviewType, UCacheParamsBuilder, gen_img, + }, preset::{ AnimaWeight, ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight, Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight, Flux2KleinBase4BWeight, @@ -24,6 +26,15 @@ macro_rules! clap_enum_variants { }}; } +#[derive(Clone, Debug, ValueEnum)] +enum CacheMode { + UCACHE, + EASYCACHE, + DBCACHE, + TAYLORSEER, + CACHEDIT, +} + #[derive(Clone, Debug, ValueEnum)] enum PreviewMode { Fast, @@ -68,6 +79,10 @@ struct Args { #[arg(short, long, default_value = "./")] output: PathBuf, + /// Caching methods accelerate diffusion inference + #[arg(long)] + cache: Option, + /// Enable preview #[arg(short, long, ignore_case = true)] preview: Option, @@ -107,57 +122,73 @@ fn main() { } println!(); - let (config, mut model_config) = PresetBuilder::default() - .preset(preset) - .prompt(&args.prompt) - .with_modifier(move |(mut config, mut model_config)| { - // atch request? - if args.batch > 1 { - config.batch_count(args.batch); - config.output(args.output); - } else { - config.output(file_name); - } - - if let Some(steps) = &args.steps { - config.steps(*steps); - } - - if args.random_seed { - config.seed(-1); - } - - if let Some(width) = args.width { - config.width(width); - } - - if let Some(height) = args.height { - config.height(height); - } - - if let Some(negative) = args.negative { - config.negative_prompt(negative); - } - - if args.low_vram { - model_config - .clip_on_cpu(true) - .vae_tiling(true) - .flash_attention(true) - .offload_params_to_cpu(true); - } - - match args.preview { - Some(PreviewMode::Fast) => config.preview_mode(PreviewType::PREVIEW_PROJ), - Some(PreviewMode::Accurate) => config.preview_mode(PreviewType::PREVIEW_VAE), - None => config.preview_mode(PreviewType::PREVIEW_NONE), - }; - config.preview_output(preview_filename); - - Ok((config, model_config)) - }) - .build() - .unwrap(); + let (config, mut model_config) = + PresetBuilder::default() + .preset(preset) + .prompt(&args.prompt) + .with_modifier(move |(mut config, mut model_config)| { + // atch request? + if args.batch > 1 { + config.batch_count(args.batch); + config.output(args.output); + } else { + config.output(file_name); + } + + if let Some(steps) = &args.steps { + config.steps(*steps); + } + + if args.random_seed { + config.seed(-1); + } + + if let Some(width) = args.width { + config.width(width); + } + + if let Some(height) = args.height { + config.height(height); + } + + if let Some(negative) = args.negative { + config.negative_prompt(negative); + } + + if args.low_vram { + model_config + .clip_on_cpu(true) + .vae_tiling(true) + .flash_attention(true) + .offload_params_to_cpu(true); + } + + match args.preview { + Some(PreviewMode::Fast) => config.preview_mode(PreviewType::PREVIEW_PROJ), + Some(PreviewMode::Accurate) => config.preview_mode(PreviewType::PREVIEW_VAE), + None => config.preview_mode(PreviewType::PREVIEW_NONE), + }; + + if let Some(cache) = args.cache { + match cache { + CacheMode::UCACHE => { + config.ucache_caching(UCacheParamsBuilder::default().build().unwrap()) + } + CacheMode::EASYCACHE => config + .easy_cache_caching(EasyCacheParamsBuilder::default().build().unwrap()), + CacheMode::DBCACHE => config + .db_cache_caching(DbCacheParamsBuilder::default().build().unwrap()), + CacheMode::TAYLORSEER => config.taylor_seer_caching(), + CacheMode::CACHEDIT => config + .cache_dit_caching(DbCacheParamsBuilder::default().build().unwrap()), + }; + } + config.preview_output(preview_filename); + + Ok((config, model_config)) + }) + .build() + .unwrap(); gen_img(&config, &mut model_config).unwrap(); println!(); diff --git a/src/api.rs b/src/api.rs index f8093bb..27ffa73 100644 --- a/src/api.rs +++ b/src/api.rs @@ -183,6 +183,7 @@ pub struct DbCacheParams { warmup: i32, /// Steps Computation Mask controls which steps can be cached + #[builder(default = "ScmPreset::default()")] scm_mask: ScmPreset, /// Scm Policy @@ -254,10 +255,10 @@ struct ScmPresetBins { impl ScmPresetBins { fn maybe_scale(&self) -> ScmPresetBins { - if self.steps == 28 || self.steps <= 0 { - return self.clone(); + if self.steps != 28 && self.steps > 0 { + return self.scale(); } - self.scale() + self.clone() } fn scale(&self) -> ScmPresetBins { @@ -295,7 +296,7 @@ impl ScmPresetBins { c_idx += 1; } if cache_idx < self.cache_bins.len() { - let cache_count = self.cache_bins[c_idx]; + let cache_count = self.cache_bins[cache_idx]; for _ in 0..cache_count { if mask.len() < self.steps as usize { mask.push(0); @@ -949,8 +950,8 @@ pub struct Config { #[builder(default = "Self::cache_init()", private)] cache: sd_cache_params_t, - #[builder(default = "CLibString::default()", private)] - scm_mask: CLibString, + #[builder(default = "None", private)] + scm_mask: Option, } impl ConfigBuilder { @@ -1033,12 +1034,11 @@ impl ConfigBuilder { ScmPolicy::Static => false, ScmPolicy::Dynamic => true, }; - self.scm_mask = Some(CLibString::from( + self.scm_mask = Some(Some(CLibString::from( params .scm_mask .to_vec_string(self.steps.unwrap_or_default()), - )); - cache.scm_mask = self.scm_mask.as_ref().unwrap().as_ptr(); + ))); self.cache = Some(cache); self @@ -1060,9 +1060,6 @@ impl ConfigBuilder { impl From for ConfigBuilder { fn from(value: Config) -> Self { let mut builder = ConfigBuilder::default(); - let mut cache = value.cache; - let scm_mask = value.scm_mask.clone(); - cache.scm_mask = scm_mask.as_ptr(); builder .pm_id_images_dir(value.pm_id_images_dir) .init_img(value.init_img) @@ -1093,8 +1090,10 @@ impl From for ConfigBuilder { .preview_mode(value.preview_mode) .preview_noisy(value.preview_noisy) .preview_interval(value.preview_interval) - .cache(cache) - .scm_mask(scm_mask); + .cache(value.cache); + if let Some(scm_mask) = value.scm_mask { + builder.scm_mask(scm_mask.clone()); + } builder } } @@ -1406,6 +1405,11 @@ fn gen_img_maybe_progress( }) .collect(); + let mut cache = config.cache; + if let Some(scm_mask) = &config.scm_mask { + cache.scm_mask = scm_mask.as_ptr(); + } + let sd_img_gen_params = sd_img_gen_params_t { prompt: prompt.as_ptr(), negative_prompt: config.negative_prompt.as_ptr(), @@ -1426,7 +1430,7 @@ fn gen_img_maybe_progress( pm_params, vae_tiling_params, auto_resize_ref_image: config.disable_auto_resize_ref_image, - cache: config.cache, + cache, loras: loras.as_ptr(), lora_count: loras.len() as u32, };