diff --git a/cli/src/main.rs b/cli/src/main.rs index 550fe9c..c475bd0 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -13,7 +13,8 @@ use diffusion_rs::{ Flux1MiniWeight, Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight, Flux2KleinBase4BWeight, Flux2KleinBase9BWeight, Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight, Preset, PresetBuilder, PresetDiscriminants, - QwenImageWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight, WeightType, ZImageTurboWeight, + QwenImageWeight, SDXS512DreamShaperWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight, + WeightType, ZImageTurboWeight, }, util::set_hf_token, }; @@ -306,7 +307,12 @@ fn get_preset(args: &Args) -> Preset { .try_into() .unwrap(), ), - PresetDiscriminants::SDXS512DreamShaper => Preset::SDXS512DreamShaper, + PresetDiscriminants::SDXS512DreamShaper => Preset::SDXS512DreamShaper( + args.weights + .unwrap_or_else(|| SDXS512DreamShaperWeight::default().into()) + .try_into() + .unwrap(), + ), PresetDiscriminants::Flux2Klein4B => Preset::Flux2Klein4B( args.weights .unwrap_or_else(|| Flux2Klein4BWeight::default().into()) diff --git a/src/preset.rs b/src/preset.rs index 92c2c17..84f57e6 100644 --- a/src/preset.rs +++ b/src/preset.rs @@ -38,7 +38,8 @@ use crate::{ Flux2Klein9BWeight(derive(Default)), Flux2KleinBase9BWeight(derive(Default)), AnimaWeight(derive(Default)), - Anima2Weight(derive(Default)) + Anima2Weight(derive(Default)), + SDXS512DreamShaperWeight(derive(Default)) )] #[derive(Debug, Clone, Copy, EnumString, VariantNames)] #[strum(ascii_case_insensitive)] @@ -50,7 +51,8 @@ pub enum WeightType { NitroSDRealismWeight, NitroSDVibrantWeight, DiffInstructStarWeight, - SSD1BWeight + SSD1BWeight, + SDXS512DreamShaperWeight(default) )] F16, #[subenum( @@ -103,7 +105,8 @@ pub enum WeightType { Flux2KleinBase4BWeight(default), Flux2Klein9BWeight, AnimaWeight(default), - Anima2Weight(default) + Anima2Weight(default), + SDXS512DreamShaperWeight )] Q8_0, Q8_1, @@ -271,7 +274,7 @@ pub enum Preset { /// Enabled [crate::api::SampleMethod::DPM2_SAMPLE_METHOD] and [crate::api::Scheduler::SMOOTHSTEP_SCHEDULER]. cfg_scale 1.0. 3 steps. Flash attention enabled. 1024x512. Vae-tiling enabled. TwinFlowZImageTurboExp(TwinFlowZImageTurboExpWeight), /// cfg_scale 1.0. 1 steps. 512x512 - SDXS512DreamShaper, + SDXS512DreamShaper(SDXS512DreamShaperWeight), /// Requires access rights to providing a token via [crate::util::set_hf_token] /// cfg scale 1.0. 4 steps. Flash attention enabled. Offload params to CPU enabled. 1024x1024. Vae-tiling enabled Flux2Klein4B(Flux2Klein4BWeight), @@ -321,7 +324,7 @@ impl Preset { Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t), Preset::DreamShaperXL2_1Turbo => dream_shaper_xl_2_1_turbo(), Preset::TwinFlowZImageTurboExp(sd_type_t) => twinflow_z_image_turbo(sd_type_t), - Preset::SDXS512DreamShaper => sdxs512_dream_shaper(), + Preset::SDXS512DreamShaper(sd_type_t) => sdxs512_dream_shaper(sd_type_t), Preset::Flux2Klein4B(sd_type_t) => flux_2_klein_4b(sd_type_t), Preset::Flux2KleinBase4B(sd_type_t) => flux_2_klein_base_4b(sd_type_t), Preset::Flux2Klein9B(sd_type_t) => flux_2_klein_9b(sd_type_t), @@ -403,8 +406,8 @@ mod tests { ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight, Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight, Flux2KleinBase4BWeight, Flux2KleinBase9BWeight, Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight, - OvisImageWeight, QwenImageWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight, - ZImageTurboWeight, + OvisImageWeight, QwenImageWeight, SDXS512DreamShaperWeight, SSD1BWeight, + TwinFlowZImageTurboExpWeight, ZImageTurboWeight, }, util::set_hf_token, }; @@ -595,7 +598,7 @@ mod tests { #[ignore] #[test] fn test_sdxs512_dream_shaper() { - run(Preset::SDXS512DreamShaper); + run(Preset::SDXS512DreamShaper(SDXS512DreamShaperWeight::Q8_0)); } #[ignore] diff --git a/src/preset_builder.rs b/src/preset_builder.rs index 30875cf..48eacbc 100644 --- a/src/preset_builder.rs +++ b/src/preset_builder.rs @@ -10,8 +10,8 @@ use crate::{ Anima2Weight, AnimaWeight, ChromaRadianceWeight, ChromaWeight, ConfigsBuilder, DiffInstructStarWeight, Flux1MiniWeight, Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight, Flux2KleinBase4BWeight, Flux2KleinBase9BWeight, Flux2Weight, - NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight, QwenImageWeight, SSD1BWeight, - TwinFlowZImageTurboExpWeight, ZImageTurboWeight, + NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight, QwenImageWeight, + SDXS512DreamShaperWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight, ZImageTurboWeight, }, }; use diffusion_rs_sys::scheduler_t; @@ -956,8 +956,17 @@ fn twinflow_z_image_turbo_weight( Ok((model_path, llm_path)) } -pub fn sdxs512_dream_shaper() -> Result { - let model = download_file_hf_hub("akleine/sdxs-512", "sdxs.safetensors")?; +pub fn sdxs512_dream_shaper(sd_type: SDXS512DreamShaperWeight) -> Result { + let model = match sd_type { + SDXS512DreamShaperWeight::F16 => { + download_file_hf_hub("akleine/sdxs-512", "sdxs.safetensors")? + } + SDXS512DreamShaperWeight::Q8_0 => download_file_hf_hub( + "concedo/sdxs-512-tinySDdistilled-GGUF", + "sdxs-512-tinySDdistilled_Q8_0.gguf", + )?, + }; + let mut config = ConfigBuilder::default(); let mut model_config = ModelConfigBuilder::default();