Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use diffusion_rs::{
Flux1MiniWeight, Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight,
Flux2KleinBase4BWeight, Flux2KleinBase9BWeight, Flux2Weight, NitroSDRealismWeight,
NitroSDVibrantWeight, OvisImageWeight, Preset, PresetBuilder, PresetDiscriminants,
QwenImageWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight, WeightType, ZImageTurboWeight,
QwenImageWeight, SDXS512DreamShaperWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight,
WeightType, ZImageTurboWeight,
},
util::set_hf_token,
};
Expand Down Expand Up @@ -306,7 +307,12 @@ fn get_preset(args: &Args) -> Preset {
.try_into()
.unwrap(),
),
PresetDiscriminants::SDXS512DreamShaper => Preset::SDXS512DreamShaper,
PresetDiscriminants::SDXS512DreamShaper => Preset::SDXS512DreamShaper(
args.weights
.unwrap_or_else(|| SDXS512DreamShaperWeight::default().into())
.try_into()
.unwrap(),
),
PresetDiscriminants::Flux2Klein4B => Preset::Flux2Klein4B(
args.weights
.unwrap_or_else(|| Flux2Klein4BWeight::default().into())
Expand Down
19 changes: 11 additions & 8 deletions src/preset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ use crate::{
Flux2Klein9BWeight(derive(Default)),
Flux2KleinBase9BWeight(derive(Default)),
AnimaWeight(derive(Default)),
Anima2Weight(derive(Default))
Anima2Weight(derive(Default)),
SDXS512DreamShaperWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
Expand All @@ -50,7 +51,8 @@ pub enum WeightType {
NitroSDRealismWeight,
NitroSDVibrantWeight,
DiffInstructStarWeight,
SSD1BWeight
SSD1BWeight,
SDXS512DreamShaperWeight(default)
)]
F16,
#[subenum(
Expand Down Expand Up @@ -103,7 +105,8 @@ pub enum WeightType {
Flux2KleinBase4BWeight(default),
Flux2Klein9BWeight,
AnimaWeight(default),
Anima2Weight(default)
Anima2Weight(default),
SDXS512DreamShaperWeight
)]
Q8_0,
Q8_1,
Expand Down Expand Up @@ -271,7 +274,7 @@ pub enum Preset {
/// Enabled [crate::api::SampleMethod::DPM2_SAMPLE_METHOD] and [crate::api::Scheduler::SMOOTHSTEP_SCHEDULER]. cfg_scale 1.0. 3 steps. Flash attention enabled. 1024x512. Vae-tiling enabled.
TwinFlowZImageTurboExp(TwinFlowZImageTurboExpWeight),
/// cfg_scale 1.0. 1 steps. 512x512
SDXS512DreamShaper,
SDXS512DreamShaper(SDXS512DreamShaperWeight),
/// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.2-dev> providing a token via [crate::util::set_hf_token]
/// cfg scale 1.0. 4 steps. Flash attention enabled. Offload params to CPU enabled. 1024x1024. Vae-tiling enabled
Flux2Klein4B(Flux2Klein4BWeight),
Expand Down Expand Up @@ -321,7 +324,7 @@ impl Preset {
Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
Preset::DreamShaperXL2_1Turbo => dream_shaper_xl_2_1_turbo(),
Preset::TwinFlowZImageTurboExp(sd_type_t) => twinflow_z_image_turbo(sd_type_t),
Preset::SDXS512DreamShaper => sdxs512_dream_shaper(),
Preset::SDXS512DreamShaper(sd_type_t) => sdxs512_dream_shaper(sd_type_t),
Preset::Flux2Klein4B(sd_type_t) => flux_2_klein_4b(sd_type_t),
Preset::Flux2KleinBase4B(sd_type_t) => flux_2_klein_base_4b(sd_type_t),
Preset::Flux2Klein9B(sd_type_t) => flux_2_klein_9b(sd_type_t),
Expand Down Expand Up @@ -403,8 +406,8 @@ mod tests {
ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight,
Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight, Flux2KleinBase4BWeight,
Flux2KleinBase9BWeight, Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight,
OvisImageWeight, QwenImageWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight,
ZImageTurboWeight,
OvisImageWeight, QwenImageWeight, SDXS512DreamShaperWeight, SSD1BWeight,
TwinFlowZImageTurboExpWeight, ZImageTurboWeight,
},
util::set_hf_token,
};
Expand Down Expand Up @@ -595,7 +598,7 @@ mod tests {
#[ignore]
#[test]
fn test_sdxs512_dream_shaper() {
run(Preset::SDXS512DreamShaper);
run(Preset::SDXS512DreamShaper(SDXS512DreamShaperWeight::Q8_0));
}

#[ignore]
Expand Down
17 changes: 13 additions & 4 deletions src/preset_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use crate::{
Anima2Weight, AnimaWeight, ChromaRadianceWeight, ChromaWeight, ConfigsBuilder,
DiffInstructStarWeight, Flux1MiniWeight, Flux1Weight, Flux2Klein4BWeight,
Flux2Klein9BWeight, Flux2KleinBase4BWeight, Flux2KleinBase9BWeight, Flux2Weight,
NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight, QwenImageWeight, SSD1BWeight,
TwinFlowZImageTurboExpWeight, ZImageTurboWeight,
NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight, QwenImageWeight,
SDXS512DreamShaperWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight, ZImageTurboWeight,
},
};
use diffusion_rs_sys::scheduler_t;
Expand Down Expand Up @@ -956,8 +956,17 @@ fn twinflow_z_image_turbo_weight(
Ok((model_path, llm_path))
}

pub fn sdxs512_dream_shaper() -> Result<ConfigsBuilder, ApiError> {
let model = download_file_hf_hub("akleine/sdxs-512", "sdxs.safetensors")?;
pub fn sdxs512_dream_shaper(sd_type: SDXS512DreamShaperWeight) -> Result<ConfigsBuilder, ApiError> {
let model = match sd_type {
SDXS512DreamShaperWeight::F16 => {
download_file_hf_hub("akleine/sdxs-512", "sdxs.safetensors")?
}
SDXS512DreamShaperWeight::Q8_0 => download_file_hf_hub(
"concedo/sdxs-512-tinySDdistilled-GGUF",
"sdxs-512-tinySDdistilled_Q8_0.gguf",
)?,
};

let mut config = ConfigBuilder::default();
let mut model_config = ModelConfigBuilder::default();

Expand Down
Loading