Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 55 additions & 28 deletions python/src/coreai_models/diffusion/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ async def _async_export_diffusion(config: DiffusionExportConfig) -> dict[str, st
_save_tokenizer(config.hf_model_id, output_path, hf_pipe, overwrite=config.overwrite)

# 4. Write pipeline.json
_write_pipeline_json(hf_pipe, config.hf_model_id, pipeline_type, output_path)
_write_metadata_json(
hf_pipe, config.hf_model_id, pipeline_type, output_path, config.compression, results
)

# Summary
logger.info("=== Export Summary ===")
Expand Down Expand Up @@ -296,33 +298,63 @@ def _save_tokenizer(model_id: str, output_path: Path, hf_pipe: Any, overwrite: b


# ---------------------------------------------------------------------------
# pipeline.json
# metadata.json (v0.2 schema — aligned with LLM and segmenter bundles)
# ---------------------------------------------------------------------------

METADATA_VERSION = "0.2"

def _write_pipeline_json(
hf_pipe: Any, model_id: str, pipeline_type: str, output_path: Path

def _write_metadata_json(
hf_pipe: Any,
model_id: str,
pipeline_type: str,
output_path: Path,
compression: str,
exported_assets: dict[str, str],
) -> None:
"""Write pipeline.json with model metadata for the Swift pipeline."""
"""Write metadata.json with the v0.2 bundle schema for diffusion models."""
from datetime import datetime

if pipeline_type == "flux2":
pipeline_json = _build_flux2_pipeline_json(hf_pipe, model_id)
diffusion_config = _build_flux2_config(hf_pipe, model_id)
else:
pipeline_json = _build_sd_pipeline_json(hf_pipe, model_id, pipeline_type)
diffusion_config = _build_sd_config(hf_pipe, model_id, pipeline_type)

# Build assets map from exported component paths
assets: dict[str, str] = {}
for name, path_str in exported_assets.items():
assets[name] = Path(path_str).name

metadata = {
"metadata_version": METADATA_VERSION,
"kind": "diffusion",
"name": output_path.name,
"assets": assets,
"diffusion": diffusion_config,
"source": {
"model_definition": "torch",
"hf_model_id": model_id,
},
"compression": compression if compression != "none" else None,
Comment thread
kevchengcodes marked this conversation as resolved.
"compilation": {
"date": datetime.now().astimezone().isoformat(),
"targets": [],
},
}

json_path = output_path / "pipeline.json"
json_path = output_path / "metadata.json"
with open(json_path, "w") as f:
json.dump(pipeline_json, f, indent=2)
logger.info(f"Saved pipeline.json to {json_path}")
json.dump(metadata, f, indent=2)
logger.info(f"Saved metadata.json to {json_path}")


def _build_flux2_pipeline_json(hf_pipe: Any, model_id: str) -> dict:
def _build_flux2_config(hf_pipe: Any, model_id: str) -> dict:
vae_config = hf_pipe.vae.config
transformer_config = hf_pipe.transformer.config

vae_scale_power = len(vae_config.block_out_channels) - 1
vae_spatial_scale = 2**vae_scale_power
default_sample_size = getattr(transformer_config, "default_sample_size", 64)
# FLUX.2 uses 2x2 patchification, so actual image size is doubled
image_size = default_sample_size * vae_spatial_scale * 2

scaling_factor = getattr(vae_config, "scaling_factor", 1.0)
Expand All @@ -333,7 +365,6 @@ def _build_flux2_pipeline_json(hf_pipe: Any, model_id: str) -> dict:
rope_theta = getattr(transformer_config, "rope_theta", 2000.0)

return {
"model_id": model_id,
"type": "flux2",
"prediction_type": "flow_matching",
"encoder_scale_factor": scaling_factor,
Expand All @@ -349,7 +380,7 @@ def _build_flux2_pipeline_json(hf_pipe: Any, model_id: str) -> dict:
}


def _build_sd_pipeline_json(hf_pipe: Any, model_id: str, pipeline_type: str = "sd") -> dict:
def _build_sd_config(hf_pipe: Any, model_id: str, pipeline_type: str = "sd") -> dict:
scheduler_config = hf_pipe.scheduler.config
vae_config = hf_pipe.vae.config

Expand All @@ -364,8 +395,7 @@ def _build_sd_pipeline_json(hf_pipe: Any, model_id: str, pipeline_type: str = "s
vae_spatial_scale = 2**vae_scale_power
image_size = denoiser_config.sample_size * vae_spatial_scale

pipeline_json: dict[str, Any] = {
"model_id": model_id,
config: dict[str, Any] = {
"type": "stable-diffusion-3" if is_sd3 else "stable-diffusion",
"prediction_type": "flow" if is_sd3 else prediction_type,
"encoder_scale_factor": scaling_factor,
Expand All @@ -376,18 +406,15 @@ def _build_sd_pipeline_json(hf_pipe: Any, model_id: str, pipeline_type: str = "s
"default_steps": 28 if is_sd3 else 50,
}

if is_sd3:
# Autodetect also works for SD3 (recognizes MMDiT / text_encoder_2
# substrings); emitting explicit paths keeps pipeline.json self-
# documenting and guards against future detect() changes.
pipeline_json["components"] = {
"text_encoder": "TextEncoder.aimodel",
"text_encoder_2": "TextEncoder2.aimodel",
"unet": "MMDiT.aimodel",
"vae_decoder": "VAEDecoder.aimodel",
}

return pipeline_json
# Include scheduler defaults for reproducibility
config["scheduler"] = {
"training_steps": getattr(scheduler_config, "num_train_timesteps", 1000),
"beta_start": getattr(scheduler_config, "beta_start", 0.00085),
"beta_end": getattr(scheduler_config, "beta_end", 0.012),
"beta_schedule": getattr(scheduler_config, "beta_schedule", "scaled_linear"),
}

return config


# ---------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@ public struct CoreAIDiffusionComponents: Sendable {
/// Errors during pipeline loading.
public enum PipelineLoadError: Error, LocalizedError {
case missingComponent(String)
case missingConfig(String)
case deprecatedFormat(String)
case configMismatch(field: String, expected: String, actual: String)

public var errorDescription: String? {
switch self {
case .missingComponent(let name):
return "Required component '\(name)' not found in model directory"
case .missingConfig(let detail):
return "Invalid bundle configuration: \(detail)"
case .deprecatedFormat(let message):
return message
case .configMismatch(let field, let expected, let actual):
return "Config mismatch for '\(field)': config says \(expected), model says \(actual)"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,27 @@ public struct PipelineDescriptor: Codable, Sendable {

/// Load or detect a pipeline descriptor from a model directory.
///
/// - `.auto`: reads `pipeline.json` if present, otherwise scans for known component filenames
/// - `.file`: reads from a specific config URL
/// - `.explicit`: uses the provided descriptor as-is
/// Priority:
/// 1. `metadata.json` (v0.2 schema with `kind: "diffusion"`)
/// 2. `pipeline.json` (legacy format, no longer supported — throws `PipelineLoadError.deprecatedFormat`)
/// 3. Directory scan for known component filenames
///
/// Fields left nil by auto-detection are filled in later during `loadComponents(from:)`
/// by inspecting the actual model descriptors.
public static func resolve(at url: URL, config: ConfigSource = .auto) throws -> PipelineDescriptor {
switch config {
case .auto:
let configURL = url.appendingPathComponent("pipeline.json")
if FileManager.default.fileExists(atPath: configURL.path) {
return try load(from: configURL)
let metadataURL = url.appendingPathComponent("metadata.json")
if FileManager.default.fileExists(atPath: metadataURL.path) {
return try loadFromMetadata(at: metadataURL)
Comment thread
stikves marked this conversation as resolved.
}
let pipelineURL = url.appendingPathComponent("pipeline.json")
if FileManager.default.fileExists(atPath: pipelineURL.path) {
throw PipelineLoadError.deprecatedFormat(
"This bundle uses the legacy pipeline.json format which is no longer supported.\n"
+ "Please re-export with `coreai.diffusion.export` to produce metadata.json.\n"
+ "See: https://github.com/apple/coreai-models/issues/TBD"
)
}
return detect(at: url)
case .file(let configURL):
Expand All @@ -126,6 +135,33 @@ public struct PipelineDescriptor: Codable, Sendable {
}
}

/// Reads a `metadata.json` bundle file (v0.2 schema) and builds a pipeline descriptor from it.
///
/// The `diffusion` block is decoded into a `PipelineDescriptor` with snake_case keys
/// converted to camelCase. The component paths are then filled in from the top-level
/// `assets` map.
///
/// - Parameter url: Location of the `metadata.json` file.
/// - Returns: The descriptor decoded from the `diffusion` block, with its text encoder,
///   second text encoder, denoiser, VAE decoder and VAE encoder paths set from `assets`
///   (entries absent from `assets` become `nil`).
/// - Throws: `PipelineLoadError.missingConfig` when the `diffusion` block or the `assets`
///   map is absent or not of the expected type. Errors from reading the file, parsing the
///   JSON, or decoding the descriptor are propagated unchanged.
public static func loadFromMetadata(at url: URL) throws -> PipelineDescriptor {
    let rawData = try Data(contentsOf: url)
    let parsed = try JSONSerialization.jsonObject(with: rawData)
    // A top-level value that is not a JSON object is treated as an empty object,
    // so it is reported through the missing-'diffusion' error below.
    let root = (parsed as? [String: Any]) ?? [:]

    guard let diffusionBlock = root["diffusion"] as? [String: Any] else {
        throw PipelineLoadError.missingConfig("metadata.json has no 'diffusion' block")
    }
    guard let assetMap = root["assets"] as? [String: String] else {
        throw PipelineLoadError.missingConfig("metadata.json has no 'assets' map")
    }

    // Re-serialize just the diffusion block so it can go through Codable decoding.
    let snakeCaseDecoder = JSONDecoder()
    snakeCaseDecoder.keyDecodingStrategy = .convertFromSnakeCase
    let diffusionPayload = try JSONSerialization.data(withJSONObject: diffusionBlock)
    var result = try snakeCaseDecoder.decode(PipelineDescriptor.self, from: diffusionPayload)

    // The denoiser may be listed under either "transformer" or "unet"; "transformer" wins.
    let denoiserPath = assetMap["transformer"] ?? assetMap["unet"]

    result.components.textEncoder = assetMap["text_encoder"]
    result.components.textEncoder2 = assetMap["text_encoder_2"]
    result.components.unet = denoiserPath
    result.components.vaeDecoder = assetMap["vae_decoder"]
    result.components.vaeEncoder = assetMap["vae_encoder"]

    return result
}

/// Parse a pipeline.json file.
/// Supports both the new format (with `components`) and the legacy format
/// (where component paths are inferred from the directory).
Expand Down
11 changes: 5 additions & 6 deletions swift/Tests/DiffusionPipelineTests/PipelineDescriptorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ struct PipelineDescriptorTests {
#expect(descriptor.components.unet == "Transformer.aimodel")
}

@Test("Resolve prefers pipeline.json over detection")
func resolvePrefersPipelineJSON() throws {
@Test("Resolve errors on legacy pipeline.json")
func resolveRejectsLegacyPipelineJSON() throws {
let dir = FileManager.default.temporaryDirectory.appendingPathComponent("sd_resolve_\(UUID())")
try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true)
defer { try? FileManager.default.removeItem(at: dir) }
Expand All @@ -175,11 +175,10 @@ struct PipelineDescriptorTests {
}
"""
try json.write(to: dir.appendingPathComponent("pipeline.json"), atomically: true, encoding: .utf8)
try "".write(to: dir.appendingPathComponent("TextEncoder.aimodel"), atomically: true, encoding: .utf8)

let descriptor = try PipelineDescriptor.resolve(at: dir)
#expect(descriptor.version == "2.0")
#expect(descriptor.components.textEncoder == "custom_encoder.aimodel")
#expect(throws: PipelineLoadError.self) {
_ = try PipelineDescriptor.resolve(at: dir)
}
}

@Test("Resolve with explicit config ignores directory")
Expand Down