From 5e2c7a7c5bf2b086e954a96bf5205a6ab2716356 Mon Sep 17 00:00:00 2001 From: rishab-partha Date: Tue, 18 Jun 2024 04:59:09 +0000 Subject: [PATCH 1/4] add latent logger + small fix to Image.Lanczos --- diffusion/callbacks/log_diffusion_images.py | 110 +++++++++++++++++--- scripts/batched-llava-caption.py | 2 +- 2 files changed, 99 insertions(+), 13 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index b60473eb..4cb1ac4c 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -3,6 +3,7 @@ """Logger for generated images.""" +import gc from math import ceil from typing import List, Optional, Tuple, Union @@ -10,6 +11,7 @@ from composer import Callback, Logger, State from composer.core import TimeUnit, get_precision_context from torch.nn.parallel import DistributedDataParallel +from transformers import AutoModel, AutoTokenizer, CLIPTextModel class LogDiffusionImages(Callback): @@ -45,7 +47,10 @@ def __init__(self, guidance_scale: float = 0.0, rescaled_guidance: Optional[float] = None, seed: Optional[int] = 1138, - use_table: bool = False): + use_table: bool = False, + text_encoder: Optional[str] = None, + clip_encoder: Optional[str] = None, + cache_dir: Optional[str] = '/tmp/hf_files'): self.prompts = prompts self.size = (size, size) if isinstance(size, int) else size self.num_inference_steps = num_inference_steps @@ -53,6 +58,7 @@ def __init__(self, self.rescaled_guidance = rescaled_guidance self.seed = seed self.use_table = use_table + self.cache_dir = cache_dir # Batch prompts batch_size = len(prompts) if batch_size is None else batch_size @@ -62,6 +68,66 @@ def __init__(self, start, end = i * batch_size, (i + 1) * batch_size self.batched_prompts.append(prompts[start:end]) + if text_encoder is not None and clip_encoder is None or text_encoder is None and clip_encoder is not None: + raise ValueError('Cannot specify only one of text encoder and CLIP encoder.') + + self.precomputed_latents = False + self.batched_latents = [] + if text_encoder: + self.precomputed_latents = True + t5_tokenizer = AutoTokenizer.from_pretrained(text_encoder, cache_dir=self.cache_dir, local_files_only=True) + clip_tokenizer = AutoTokenizer.from_pretrained(clip_encoder, + subfolder='tokenizer', + cache_dir=self.cache_dir, + local_files_only=True) + + t5_model = AutoModel.from_pretrained(text_encoder, + torch_dtype=torch.float16, + cache_dir=self.cache_dir, + local_files_only=True).encoder.cuda().eval() + clip_model = CLIPTextModel.from_pretrained(clip_encoder, + subfolder='text_encoder', + torch_dtype=torch.float16, + cache_dir=self.cache_dir, + local_files_only=True).cuda().eval() + + for batch in self.batched_prompts: + latent_batch = {} + tokenized_t5 = t5_tokenizer(batch, + padding='max_length', + max_length=t5_tokenizer.model.max_length, + truncation=True, + return_tensors='pt') + t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda() + t5_ids = tokenized_t5['input_ids'].cuda() + t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu() + t5_attention_mask = t5_attention_mask.cpu().to(torch.long) + + tokenized_clip = clip_tokenizer(batch, + padding='max_length', + max_length=t5_tokenizer.model.max_length, + truncation=True, + return_tensors='pt') + clip_attention_mask = tokenized_clip['attention_mask'].cuda() + clip_ids = tokenized_clip['input_ids'].cuda() + clip_outputs = clip_model(input_ids=clip_ids, + attention_mask=clip_attention_mask, + output_hidden_states=True) + clip_latents = clip_outputs.hidden_states[-2].cpu() + clip_pooled = clip_outputs[-1].cpu() + clip_attention_mask = clip_attention_mask.cpu().to(torch.long) + + latent_batch['T5_LATENTS'] = t5_latents + latent_batch['CLIP_LATENTS'] = clip_latents + latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1) + latent_batch['CLIP_POOLED'] = clip_pooled + self.batched_latents.append(latent_batch) + + del t5_model + del clip_model + gc.collect() + torch.cuda.empty_cache() + def eval_start(self, state: State, logger: Logger): # Get the model object if it has been wrapped by DDP to access the image generation function. if isinstance(state.model, DistributedDataParallel): @@ -72,17 +138,37 @@ def eval_start(self, state: State, logger: Logger): # Generate images with get_precision_context(state.precision): all_gen_images = [] - for batch in self.batched_prompts: - gen_images = model.generate( - prompt=batch, # type: ignore - height=self.size[0], - width=self.size[1], - guidance_scale=self.guidance_scale, - rescaled_guidance=self.rescaled_guidance, - progress_bar=False, - num_inference_steps=self.num_inference_steps, - seed=self.seed) - all_gen_images.append(gen_images) + if self.precomputed_latents: + for batch in self.batched_latents: + pooled_prompt = batch['CLIP_POOLED'].cuda() + prompt_mask = batch['ATTENTION_MASK'].cuda() + t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda()) + clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda()) + prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1) + + gen_images = model.generate(prompt_embeds=prompt_embeds, + pooled_prompt=pooled_prompt, + prompt_mask=prompt_mask, + height=self.size[0], + width=self.size[1], + guidance_scale=self.guidance_scale, + rescaled_guidance=self.rescaled_guidance, + progress_bar=False, + num_inference_steps=self.num_inference_steps, + seed=self.seed) + all_gen_images.append(gen_images) + else: + for batch in self.batched_prompts: + gen_images = model.generate( + prompt=batch, # type: ignore + height=self.size[0], + width=self.size[1], + guidance_scale=self.guidance_scale, + rescaled_guidance=self.rescaled_guidance, + progress_bar=False, + num_inference_steps=self.num_inference_steps, + seed=self.seed) + all_gen_images.append(gen_images) gen_images = torch.cat(all_gen_images) # Log images to wandb diff --git a/scripts/batched-llava-caption.py b/scripts/batched-llava-caption.py index 347a4db9..558e568c 100644 --- a/scripts/batched-llava-caption.py +++ b/scripts/batched-llava-caption.py @@ -164,7 +164,7 @@ def resize_and_pad(self, image: Image.Image) -> Image.Image: resize_width = round(resize_height * aspect_ratio) else: raise ValueError('Invalid image dimensions') - resized_image = image.resize((resize_width, resize_height), Image.Resampling.LANCZOS) + resized_image = image.resize((resize_width, resize_height), Image.LANCZOS) # Calculate padding pad_width_left = (self.width - resize_width) // 2 From 4ee33961b7f5e6a9223b409f470e52f62aac6017 Mon Sep 17 00:00:00 2001 From: rishab-partha Date: Tue, 18 Jun 2024 07:44:55 +0000 Subject: [PATCH 2/4] revert PIL change --- scripts/batched-llava-caption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/batched-llava-caption.py b/scripts/batched-llava-caption.py index 558e568c..347a4db9 100644 --- a/scripts/batched-llava-caption.py +++ b/scripts/batched-llava-caption.py @@ -164,7 +164,7 @@ def resize_and_pad(self, image: Image.Image) -> Image.Image: resize_width = round(resize_height * aspect_ratio) else: raise ValueError('Invalid image dimensions') - resized_image = image.resize((resize_width, resize_height), Image.LANCZOS) + resized_image = image.resize((resize_width, resize_height), Image.Resampling.LANCZOS) # Calculate padding pad_width_left = (self.width - resize_width) // 2 From 5f809ea361491d902d4a84f072f960c340c58724 Mon Sep 17 00:00:00 2001 From: rishab-partha Date: Mon, 29 Jul 2024 16:14:46 -0700 Subject: [PATCH 3/4] fixes? --- diffusion/callbacks/log_diffusion_images.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index 4cb1ac4c..402ffd5d 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -37,6 +37,9 @@ class LogDiffusionImages(Callback): seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation. Default: ``1138``. use_table (bool): Whether to make a table of the images or not. Default: ``False``. + t5_encoder (str, optional): path to the T5 encoder to as a second text encoder. + clip_encoder (str, optional): path to the CLIP encoder as the first text encoder. + cache_dir: (str, optional): path for HF to cache files while downloading model """ def __init__(self, @@ -48,7 +51,7 @@ def __init__(self, rescaled_guidance: Optional[float] = None, seed: Optional[int] = 1138, use_table: bool = False, - text_encoder: Optional[str] = None, + t5_encoder: Optional[str] = None, clip_encoder: Optional[str] = None, cache_dir: Optional[str] = '/tmp/hf_files'): self.prompts = prompts @@ -68,20 +71,20 @@ def __init__(self, start, end = i * batch_size, (i + 1) * batch_size self.batched_prompts.append(prompts[start:end]) - if text_encoder is not None and clip_encoder is None or text_encoder is None and clip_encoder is not None: + if t5_encoder is not None and clip_encoder is None or t5_encoder is None and clip_encoder is not None: raise ValueError('Cannot specify only one of text encoder and CLIP encoder.') self.precomputed_latents = False self.batched_latents = [] - if text_encoder: + if t5_encoder: self.precomputed_latents = True - t5_tokenizer = AutoTokenizer.from_pretrained(text_encoder, cache_dir=self.cache_dir, local_files_only=True) + t5_tokenizer = AutoTokenizer.from_pretrained(t5_encoder, cache_dir=self.cache_dir, local_files_only=True) clip_tokenizer = AutoTokenizer.from_pretrained(clip_encoder, subfolder='tokenizer', cache_dir=self.cache_dir, local_files_only=True) - t5_model = AutoModel.from_pretrained(text_encoder, + t5_model = AutoModel.from_pretrained(t5_encoder, torch_dtype=torch.float16, cache_dir=self.cache_dir, local_files_only=True).encoder.cuda().eval() @@ -114,7 +117,7 @@ def __init__(self, attention_mask=clip_attention_mask, output_hidden_states=True) clip_latents = clip_outputs.hidden_states[-2].cpu() - clip_pooled = clip_outputs[-1].cpu() + clip_pooled = clip_outputs[1].cpu() clip_attention_mask = clip_attention_mask.cpu().to(torch.long) latent_batch['T5_LATENTS'] = t5_latents From 79970a24036208728e4761e99995fa1d9702a7d7 Mon Sep 17 00:00:00 2001 From: rishab-partha Date: Fri, 9 Aug 2024 17:24:03 -0700 Subject: [PATCH 4/4] controlnet v1 --- diffusion/callbacks/__init__.py | 2 + .../callbacks/assign_controlnet_weight.py | 49 ++ diffusion/models/__init__.py | 4 +- diffusion/models/controlnet.py | 617 ++++++++++++++++++ diffusion/models/models.py | 476 +++++++++++++- 5 files changed, 1146 insertions(+), 2 deletions(-) create mode 100644 diffusion/callbacks/assign_controlnet_weight.py create mode 100644 diffusion/models/controlnet.py diff --git a/diffusion/callbacks/__init__.py b/diffusion/callbacks/__init__.py index 48727c72..d9b35885 100644 --- a/diffusion/callbacks/__init__.py +++ b/diffusion/callbacks/__init__.py @@ -7,8 +7,10 @@ from diffusion.callbacks.log_latent_statistics import LogLatentStatistics from diffusion.callbacks.nan_catcher import NaNCatcher from diffusion.callbacks.scheduled_garbage_collector import ScheduledGarbageCollector +from diffusion.callbacks.assign_controlnet_weight import AssignControlNet __all__ = [ + 'AssignControlNet', 'LogAutoencoderImages', 'LogDiffusionImages', 'LogLatentStatistics', diff --git a/diffusion/callbacks/assign_controlnet_weight.py b/diffusion/callbacks/assign_controlnet_weight.py new file mode 100644 index 00000000..76dc8a74 --- /dev/null +++ b/diffusion/callbacks/assign_controlnet_weight.py @@ -0,0 +1,49 @@ +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from composer import Callback, Logger, State +from composer.core import get_precision_context +from torch.nn.parallel import DistributedDataParallel +from diffusers import ControlNetModel, UNet2DConditionModel + +class AssignControlNet(Callback): + """Assigns Controlnet weights to the controlnet from the Unet after composer loads the checkpoint + + Args: + use_fsdp: whether or not the model is FSDP wrapped + """ + + def __init__(self, use_fsdp): + self.use_fsdp = use_fsdp + + def process_controlnet(self, controlnet: ControlNetModel, unet: UNet2DConditionModel): + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + if hasattr(controlnet, "add_embedding"): + controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + def after_load(self, state: State, logger: Logger): + # Get the model object if it has been wrapped by DDP to access the image generation function. + if isinstance(state.model, DistributedDataParallel): + model = state.model.module + else: + model = state.model + + # Load checkpoint + if model.load_controlnet_from_composer: + with get_precision_context(state.precision): + if self.use_fsdp: + with FSDP.summon_full_params(model.unet, recurse = True, writeback = False): + with FSDP.summon_full_params(model.controlnet, recurse = True, writeback = True): + self.process_controlnet(model.controlnet, model.unet) + + else: + self.process_controlnet(model.controlnet, model.unet) + \ No newline at end of file diff --git a/diffusion/models/__init__.py b/diffusion/models/__init__.py index 67f61ac2..d443f677 100644 --- a/diffusion/models/__init__.py +++ b/diffusion/models/__init__.py @@ -4,7 +4,7 @@ """Diffusion models.""" from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion, - discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl) + discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl, stable_diffusion_2_controlnet, stable_diffusion_xl_controlnet) from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion @@ -16,5 +16,7 @@ 'PixelDiffusion', 'stable_diffusion_2', 'stable_diffusion_xl', + 'stable_diffusion_2_controlnet', + 'stable_diffusion_xl_controlnet', 'StableDiffusion', ] diff --git a/diffusion/models/controlnet.py b/diffusion/models/controlnet.py new file mode 100644 index 00000000..79409fbd --- /dev/null +++ b/diffusion/models/controlnet.py @@ -0,0 +1,617 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""ControlNet Diffusion models.""" + +from contextlib import nullcontext +from typing import List, Optional, Tuple, Union +import PIL +import torch +import torch.nn.functional as F +from composer.models import ComposerModel +from composer.utils import dist +from scipy.stats import qmc +from torchmetrics import MeanSquaredError +from tqdm.auto import tqdm +from diffusers.image_processor import VaeImageProcessor + +class ControlNet(ComposerModel): + def __init__(self, + unet, + vae, + text_encoder, + tokenizer, + controlnet, + noise_scheduler, + inference_noise_scheduler, + loss_fn=F.mse_loss, + prediction_type: str = 'epsilon', + latent_mean: Optional[Tuple[float]] = None, + latent_std: Optional[Tuple[float]] = None, + downsample_factor: int = 8, + offset_noise: Optional[float] = None, + train_metrics: Optional[List] = None, + val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, + val_seed: int = 1138, + image_key: str = 'image', + text_key: str = 'captions', + image_latents_key: str = 'image_latents', + text_latents_key: str = 'caption_latents', + controlnet_images_key: str = 'controlnet_images', + precomputed_latents: bool = False, + encode_latents_in_fp16: bool = False, + mask_pad_tokens: bool = False, + fsdp: bool = False, + sdxl: bool = False, + load_controlnet_from_composer: bool = False): + super().__init__() + self.unet = unet + self.vae = vae + self.load_controlnet_from_composer = load_controlnet_from_composer + self.controlnet = controlnet + self.noise_scheduler = noise_scheduler + self.loss_fn = loss_fn + self.prediction_type = prediction_type.lower() + if self.prediction_type not in ['sample', 'epsilon', 'v_prediction']: + raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}') + self.downsample_factor = downsample_factor + self.offset_noise = offset_noise + self.quasirandomness = quasirandomness + self.train_seed = train_seed + self.val_seed = val_seed + self.image_key = image_key + self.controlnet_images_key = controlnet_images_key + self.image_latents_key = image_latents_key + self.precomputed_latents = precomputed_latents + self.mask_pad_tokens = mask_pad_tokens + self.sdxl = sdxl + if latent_mean is None: + self.latent_mean = 4 * (0.0) + if latent_std is None: + self.latent_std = 4 * (1 / 0.13025,) if self.sdxl else 4 * (1 / 0.18215,) + self.latent_mean = torch.tensor(latent_mean).view(1, -1, 1, 1) + self.latent_std = torch.tensor(latent_std).view(1, -1, 1, 1) + self.train_metrics = train_metrics if train_metrics is not None else [MeanSquaredError()] + self.val_metrics = val_metrics if val_metrics is not None else [MeanSquaredError()] + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.inference_scheduler = inference_noise_scheduler + self.text_key = text_key + self.text_latents_key = text_latents_key + self.encode_latents_in_fp16 = encode_latents_in_fp16 + self.mask_pad_tokens = mask_pad_tokens + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.downsample_factor, do_convert_rgb=True, do_normalize=False + ) + # freeze text_encoder during diffusion training + self.text_encoder.requires_grad_(False) + self.vae.requires_grad_(False) + self.unet.requires_grad_(False) + self.controlnet.train() + if self.encode_latents_in_fp16: + self.text_encoder = self.text_encoder.half() + self.vae = self.vae.half() + if fsdp: + # only wrap models in the diffusion process + self.text_encoder._fsdp_wrap = False + self.vae._fsdp_wrap = False + self.unet._fsdp_wrap = True + self.controlnet._fsdp_wrap = True + + # Optional rng generator + self.rng_generator: Optional[torch.Generator] = None + if self.quasirandomness: + self.sobol = qmc.Sobol(d=1, scramble=True, seed=self.train_seed) + + def _apply(self, fn): + super(ControlNet, self)._apply(fn) + self.latent_mean = fn(self.latent_mean) + self.latent_std = fn(self.latent_std) + return self + + def _generate_timesteps(self, latents: torch.Tensor): + if self.quasirandomness: + # Generate a quasirandom sequence of timesteps equal to the global batch size + global_batch_size = latents.shape[0] * dist.get_world_size() + sampled_fractions = torch.tensor(self.sobol.random(global_batch_size), device=latents.device) + timesteps = (len(self.noise_scheduler) * sampled_fractions).squeeze() + timesteps = torch.floor(timesteps).long() + # Get this device's subset of all the timesteps + idx_offset = dist.get_global_rank() * latents.shape[0] + timesteps = timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) + else: + timesteps = torch.randint(0, + len(self.noise_scheduler), (latents.shape[0],), + device=latents.device, + generator=self.rng_generator) + return timesteps + + def set_rng_generator(self, rng_generator: torch.Generator): + """Sets the rng generator for the model.""" + self.rng_generator = rng_generator + + def forward(self, batch): + latents, text_embeds, text_pooled_embeds, attention_mask, encoder_attention_mask = None, None, None, None, None + if 'attention_mask' in batch and self.mask_pad_tokens: + attention_mask = batch['attention_mask'] # mask for text encoders + encoder_attention_mask = _create_unet_attention_mask(attention_mask) # text mask for U-Net + + # Use latents if specified and available. When specified, they might not exist during eval + if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch: + if self.sdxl: + raise NotImplementedError('SDXL not yet supported with precomputed latents') + latents, text_embeds = batch[self.image_latents_key], batch[self.text_latents_key] + else: + inputs, conditionings = batch[self.image_key], batch[self.text_key] + + # If encode_latents_in_fp16, disable autocast context as models are in fp16 + c = torch.cuda.amp.autocast(enabled=False) if self.encode_latents_in_fp16 else nullcontext() # type: ignore + with c: + # Encode the images to the latent space. + if self.encode_latents_in_fp16: + latents = self.vae.encode(inputs.half())['latent_dist'].sample().data + else: + latents = self.vae.encode(inputs)['latent_dist'].sample().data + # Encode tokenized prompt into embedded text and pooled text embeddings + text_encoder_out = self.text_encoder(conditionings, attention_mask=attention_mask) + text_embeds = text_encoder_out[0] + if self.sdxl: + if len(text_encoder_out) <= 1: + raise RuntimeError('SDXL requires text encoder output to include a pooled text embedding') + text_pooled_embeds = text_encoder_out[1] + + # Scale the latents + latents = (latents - self.latent_mean) / self.latent_std + + # Zero dropped captions if needed + if 'drop_caption_mask' in batch.keys(): + text_embeds *= batch['drop_caption_mask'].view(-1, 1, 1) + if text_pooled_embeds is not None: + text_pooled_embeds *= batch['drop_caption_mask'].view(-1, 1) + + # Sample the diffusion timesteps + timesteps = self._generate_timesteps(latents) + # Add noise to the inputs (forward diffusion) + noise = torch.randn(*latents.shape, device=latents.device, generator=self.rng_generator) + if self.offset_noise is not None: + offset_noise = torch.randn(latents.shape[0], + latents.shape[1], + 1, + 1, + device=noise.device, + generator=self.rng_generator) + noise += self.offset_noise * offset_noise + noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + + c = torch.cuda.amp.autocast(enabled=False) if self.encode_latents_in_fp16 else nullcontext() # type: ignore + with c: + # Encode the images to the latent space. + if self.encode_latents_in_fp16: + controlnet_image = batch[self.controlnet_images_key].half() + else: + controlnet_image = batch[self.controlnet_images_key] + + down_block_sample, mid_block_sample = self.controlnet( + noised_latents, + timesteps, + text_embeds, + controlnet_image, + attention_mask=encoder_attention_mask, + return_dict=False + ) + # Generate the targets + if self.prediction_type == 'epsilon': + targets = noise + elif self.prediction_type == 'sample': + targets = latents + elif self.prediction_type == 'v_prediction': + targets = self.noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f'prediction type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}') + + added_cond_kwargs = {} + # if using SDXL, prepare added time ids & embeddings + if self.sdxl: + add_time_ids = torch.cat( + [batch['cond_original_size'], batch['cond_crops_coords_top_left'], batch['cond_target_size']], dim=1) + added_cond_kwargs = {'text_embeds': text_pooled_embeds, 'time_ids': add_time_ids} + + # Forward through the model + return self.unet(noised_latents, + timesteps, + text_embeds, + encoder_attention_mask=encoder_attention_mask, + down_block_additional_residuals = down_block_sample, + mid_block_additional_residual = mid_block_sample, + added_cond_kwargs=added_cond_kwargs)['sample'], targets, timesteps + + def loss(self, outputs, batch): + """Loss between unet output and added noise, typically mse.""" + return self.loss_fn(outputs[0], outputs[1]) + + def eval_forward(self, batch, outputs=None): + """For stable diffusion, eval forward computes unet outputs as well as some samples.""" + # Skip this if outputs have already been computed, e.g. during training + if outputs is not None: + return outputs + return self.forward(batch) + + def get_metrics(self, is_train: bool = False): + if is_train: + metrics = self.train_metrics + else: + metrics = self.val_metrics + metrics_dict = {metric.__class__.__name__: metric for metric in metrics} + return metrics_dict + + def update_metric(self, batch, outputs, metric): + metric.update(outputs[0], outputs[1]) + + @torch.no_grad() + def generate( + self, + image: Union[PIL.Image.Image, torch.Tensor, List[PIL.Image.Image], List[torch.Tensor]] = None, + prompt: Optional[list] = None, + negative_prompt: Optional[list] = None, + tokenized_prompts: Optional[torch.LongTensor] = None, + tokenized_negative_prompts: Optional[torch.LongTensor] = None, + tokenized_prompts_pad_mask: Optional[torch.LongTensor] = None, + tokenized_negative_prompts_pad_mask: Optional[torch.LongTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 3.0, + rescaled_guidance: Optional[float] = None, + num_images_per_prompt: Optional[int] = 1, + seed: Optional[int] = None, + progress_bar: Optional[bool] = True, + zero_out_negative_prompt: bool = True, + crop_params: Optional[torch.Tensor] = None, + input_size_params: Optional[torch.Tensor] = None, + controlnet_conditioning_scale: Optional[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: Optional[float] = 0.0, + control_guidance_end: Optional[float] = 1.0, + ): + """Generates image from noise. + + Performs the backward diffusion process, each inference step takes + one forward pass through the unet. + + Args: + prompt (str or List[str]): The prompt or prompts to guide the image generation. + image (PIL.Image.Image, torch.Tensor, List[PIL.Image.Image], List[torch.Tensor): the images + for controlnet guidance. + negative_prompt (str or List[str]): The prompt or prompts to guide the + image generation away from. Ignored when not using guidance + (i.e., ignored if guidance_scale is less than 1). + Must be the same length as list of prompts. Default: `None`. + tokenized_prompts (torch.LongTensor): Optionally pass pre-tokenized prompts instead + of string prompts. If SDXL, this will be a tensor of size [B, 2, max_length], + otherwise will be of shape [B, max_length]. Default: `None`. + tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative + prompts instead of string prompts. Default: `None`. + tokenized_prompts_pad_mask (torch.LongTensor): Optionally pass padding mask for + pre-tokenized prompts. Default `None`. + tokenized_negative_prompts_pad_mask (torch.LongTensor): Optionall pass padding mask for + pre-tokenized negative prompts. Default `None`. + prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead + of string prompts. If both prompt and prompt_embeds + are passed, prompt_embeds will be used. Default: `None`. + negative_prompt_embeds (torch.FloatTensor): Optionally pass pre-embedded negative + prompts instead of string negative prompts. If both negative_prompt and + negative_prompt_embeds are passed, prompt_embeds will be used. Default: `None`. + height (int, optional): The height in pixels of the generated image. + Default: `self.unet.config.sample_size * 8)`. + width (int, optional): The width in pixels of the generated image. + Default: `self.unet.config.sample_size * 8)`. + num_inference_steps (int): The number of denoising steps. + More denoising steps usually lead to a higher quality image at the expense + of slower inference. Default: `50`. + guidance_scale (float): Guidance scale as defined in + Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation + 2. of Imagen Paper. Guidance scale is enabled by setting guidance_scale > 1. + Higher guidance scale encourages to generate images that are closely linked + to the text prompt, usually at the expense of lower image quality. + Default: `3.0`. + rescaled_guidance (float, optional): Rescaled guidance scale. If not specified, rescaled guidance will + not be used. Default: `None`. + num_images_per_prompt (int): The number of images to generate per prompt. + Default: `1`. + progress_bar (bool): Whether to use the tqdm progress bar during generation. + Default: `True`. + seed (int): Random seed to use for generation. Set a seed for reproducible generation. + Default: `None`. + zero_out_negative_prompt (bool): Whether or not to zero out negative prompt if it is + an empty string. Default: `True`. + crop_params (torch.FloatTensor of size [Bx2], optional): Crop parameters to use + when generating images with SDXL. Default: `None`. + input_size_params (torch.FloatTensor of size [Bx2], optional): Size parameters + (representing original size of input image) to use when generating images with SDXL. + Default: `None`. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + """ + _check_prompt_given(prompt, tokenized_prompts, prompt_embeds) + _check_prompt_lenths(prompt, negative_prompt) + _check_prompt_lenths(tokenized_prompts, tokenized_negative_prompts) + _check_prompt_lenths(prompt_embeds, negative_prompt_embeds) + + guess_mode = guess_mode or self.controlnet.config.global_pool_conditions + # Create rng for the generation + device = self.vae.device + rng_generator = torch.Generator(device=device) + if seed: + rng_generator = rng_generator.manual_seed(seed) # type: ignore + + height = height or self.unet.config.sample_size * self.downsample_factor + width = width or self.unet.config.sample_size * self.downsample_factor + assert height is not None # for type checking + assert width is not None # for type checking + + do_classifier_free_guidance = guidance_scale > 1.0 # type: ignore + + text_embeddings, pooled_text_embeddings, pad_attn_mask = self._prepare_text_embeddings( + prompt, tokenized_prompts, tokenized_prompts_pad_mask, prompt_embeds, num_images_per_prompt) + batch_size = len(text_embeddings) # len prompts * num_images_per_prompt + # classifier free guidance + negative prompts + # negative prompt is given in place of the unconditional input in classifier free guidance + pooled_embeddings, encoder_attn_mask = pooled_text_embeddings, pad_attn_mask + if do_classifier_free_guidance: + if not negative_prompt and not tokenized_negative_prompts and not negative_prompt_embeds and zero_out_negative_prompt: + # Negative prompt is empty and we want to zero it out + unconditional_embeddings = torch.zeros_like(text_embeddings) + pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) if self.sdxl else None + uncond_pad_attn_mask = torch.zeros_like(pad_attn_mask) if pad_attn_mask is not None else None + else: + if not negative_prompt: + negative_prompt = [''] * (batch_size // num_images_per_prompt) # type: ignore + unconditional_embeddings, pooled_unconditional_embeddings, uncond_pad_attn_mask = self._prepare_text_embeddings( + negative_prompt, tokenized_negative_prompts, tokenized_negative_prompts_pad_mask, + negative_prompt_embeds, num_images_per_prompt) + + # concat uncond + prompt + text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]) + if self.sdxl: + pooled_embeddings = torch.cat([pooled_unconditional_embeddings, pooled_text_embeddings]) # type: ignore + if pad_attn_mask is not None: + encoder_attn_mask = torch.cat([uncond_pad_attn_mask, pad_attn_mask]) # type: ignore + else: + if self.sdxl: + pooled_embeddings = pooled_text_embeddings + + image = self._prepare_image(image, width, height, batch_size, num_images_per_prompt, device, self.controlnet.dtype, do_classifier_free_guidance, guess_mode) + height, width = image.shape[-2:] + + # prepare for diffusion generation process + latents = torch.randn( + (batch_size, self.unet.config.in_channels, height // self.downsample_factor, + width // self.downsample_factor), + device=device, + dtype=self.unet.dtype, + generator=rng_generator, + ) + + self.inference_scheduler.set_timesteps(num_inference_steps) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.inference_scheduler.init_noise_sigma + + added_cond_kwargs = {} + # if using SDXL, prepare added time ids & embeddings + if self.sdxl and pooled_embeddings is not None: + if crop_params is None: + crop_params = torch.zeros((batch_size, 2), dtype=text_embeddings.dtype) + if input_size_params is None: + input_size_params = torch.tensor([width, height], dtype=text_embeddings.dtype).repeat(batch_size, 1) + output_size_params = torch.tensor([width, height], dtype=text_embeddings.dtype).repeat(batch_size, 1) + + if do_classifier_free_guidance: + crop_params = torch.cat([crop_params, crop_params]) + input_size_params = torch.cat([input_size_params, input_size_params]) + output_size_params = torch.cat([output_size_params, output_size_params]) + + add_time_ids = torch.cat([input_size_params, crop_params, output_size_params], dim=1).to(device) + added_cond_kwargs = {'text_embeds': pooled_embeddings, 'time_ids': add_time_ids} + + use_controlnet = [] + for i in range(num_inference_steps): + use_controlnet.append(1.0 - float(i / num_inference_steps < control_guidance_start or (i + 1) / num_inference_steps > control_guidance_end)) + index = 0 + # backward diffusion process + for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): + if do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents + + latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": pooled_embeddings.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + cond_scale = controlnet_conditioning_scale*use_controlnet[index] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # Model prediction + pred = self.unet(latent_model_input, + t, + encoder_hidden_states=text_embeddings, + encoder_attention_mask=encoder_attn_mask, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs).sample + + if do_classifier_free_guidance: + # perform guidance. Note this is only techincally correct for prediction_type 'epsilon' + pred_uncond, pred_text = pred.chunk(2) + pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) + # Optionally rescale the classifer free guidance + if rescaled_guidance is not None: + std_pos = torch.std(pred_text, dim=(1, 2, 3), keepdim=True) + std_cfg = torch.std(pred, dim=(1, 2, 3), keepdim=True) + pred_rescaled = pred * (std_pos / std_cfg) + pred = pred_rescaled * rescaled_guidance + pred * (1 - rescaled_guidance) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inference_scheduler.step(pred, t, latents, generator=rng_generator).prev_sample + index += 1 + + # We now use the vae to decode the generated latents back into the image. + # scale and decode the image latents with vae + latents = latents * self.latent_std + self.latent_mean + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + return image.detach() # (batch*num_images_per_prompt, channel, h, w) + + def _prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def _prepare_text_embeddings(self, prompt, tokenized_prompts, tokenized_pad_mask, prompt_embeds, + num_images_per_prompt): + """Tokenizes and embeds prompts if needed, then duplicates embeddings to support multiple generations per prompt.""" + device = self.text_encoder.device + pooled_text_embeddings = None + if prompt_embeds is None: + if tokenized_prompts is None: + tokenized_out = self.tokenizer(prompt, + padding='max_length', + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + tokenized_prompts = tokenized_out['input_ids'] + if self.mask_pad_tokens: + tokenized_pad_mask = tokenized_out['attention_mask'] + else: + tokenized_pad_mask = None + if tokenized_pad_mask is not None: + tokenized_pad_mask = tokenized_pad_mask.to(device) + text_encoder_out = self.text_encoder(tokenized_prompts.to(device), attention_mask=tokenized_pad_mask) + prompt_embeds = text_encoder_out[0] + if self.sdxl: + if len(text_encoder_out) <= 1: + raise RuntimeError('SDXL requires text encoder output to include a pooled text embedding') + pooled_text_embeddings = text_encoder_out[1] + else: + if self.sdxl: + raise NotImplementedError('SDXL not yet supported with precomputed embeddings') + + # duplicate text embeddings for each generation per prompt + prompt_embeds = _duplicate_tensor(prompt_embeds, num_images_per_prompt) + + if not self.mask_pad_tokens: + tokenized_pad_mask = None + + if tokenized_pad_mask is not None: + tokenized_pad_mask = _create_unet_attention_mask(tokenized_pad_mask) + tokenized_pad_mask = _duplicate_tensor(tokenized_pad_mask, num_images_per_prompt) + + if self.sdxl and pooled_text_embeddings is not None: + pooled_text_embeddings = _duplicate_tensor(pooled_text_embeddings, num_images_per_prompt) + return prompt_embeds, pooled_text_embeddings, tokenized_pad_mask + +def _check_prompt_lenths(prompt, negative_prompt): + if prompt is None and negative_prompt is None: + return + batch_size = 1 if isinstance(prompt, str) else len(prompt) + if negative_prompt: + negative_prompt_bs = 1 if isinstance(negative_prompt, str) else len(negative_prompt) + if negative_prompt_bs != batch_size: + raise ValueError('len(prompts) and len(negative_prompts) must be the same. \ + A negative prompt must be provided for each given prompt.') + + +def _check_prompt_given(prompt, tokenized_prompts, prompt_embeds): + if prompt is None and tokenized_prompts is None and prompt_embeds is None: + raise ValueError('Must provide one of `prompt`, `tokenized_prompts`, or `prompt_embeds`') + + +def _create_unet_attention_mask(attention_mask): + """Takes the union of multiple attention masks if given more than one mask.""" + if len(attention_mask.shape) == 2: + return attention_mask + elif len(attention_mask.shape) == 3: + encoder_attention_mask = attention_mask[:, 0] + for i in range(1, attention_mask.shape[1]): + encoder_attention_mask |= attention_mask[:, i] + return encoder_attention_mask + else: + raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_mask.shape}') + + +def _duplicate_tensor(tensor, num_images_per_prompt): + """Duplicate tensor for multiple generations from a single prompt.""" + batch_size, seq_len = tensor.shape[:2] + tensor = tensor.repeat(1, num_images_per_prompt, *[ + 1, + ] * len(tensor.shape[2:])) + return tensor.view(batch_size * num_images_per_prompt, seq_len, *[ + -1, + ] * len(tensor.shape[2:])) \ No newline at end of file diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 562a08fb..cfe3965c 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -8,7 +8,7 @@ import torch from composer.devices import DeviceGPU -from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel +from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel, ControlNetModel from torchmetrics import MeanSquaredError from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig @@ -18,6 +18,7 @@ from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer +from diffusion.models.controlnet import ControlNet from diffusion.schedulers.schedulers import ContinuousTimeScheduler try: @@ -218,6 +219,212 @@ def stable_diffusion_2( return model +def stable_diffusion_2_controlnet( + model_name: str = 'stabilityai/stable-diffusion-2-base', + pretrained: bool = True, + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + prediction_type: str = 'epsilon', + controlnet_from_composer_unet: bool = False, + controlnet: Optional[str] = None, + latent_mean: Union[float, Tuple, str] = 0.0, + latent_std: Union[float, Tuple, str] = 5.489980785067252, + beta_schedule: str = 'scaled_linear', + zero_terminal_snr: bool = False, + offset_noise: Optional[float] = None, + train_metrics: Optional[List] = None, + val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, + val_seed: int = 1138, + precomputed_latents: bool = False, + encode_latents_in_fp16: bool = True, + mask_pad_tokens: bool = False, + fsdp: bool = True, + clip_qkv: Optional[float] = None, + use_xformers: bool = True, +): + """Stable diffusion v2 training setup. + + Requires batches of matched images and text prompts to train. Generates images from text + prompts. + + Args: + model_name (str): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'. + pretrained (bool): Whether to load pretrained weights. Defaults to True. + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, + will use the vae from `model_name`. Default `None`. + autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. + prediction_type (str): The type of prediction to use. Must be one of 'sample', + 'epsilon', or 'v_prediction'. Default: `epsilon`. + latent_mean (float, list, str): The mean of the autoencoder latents. Either a float for a single value, + a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `0.0`. + latent_std (float, list, str): The std. dev. of the autoencoder latents. Either a float for a single value, + a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `1/0.18215`. + beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. + Default: `scaled_linear`. + zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. + train_metrics (list, optional): List of metrics to compute during training. If None, defaults to + [MeanSquaredError()]. + val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to + [MeanSquaredError()]. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. + train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. + val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. + precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. + offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not + be used. Default `None`. + encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. + mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False. + fsdp (bool): Whether to use FSDP. Defaults to True. + clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to None. + use_xformers (bool): Whether to use xformers for attention. Defaults to True. + """ + latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) + + if train_metrics is None: + train_metrics = [MeanSquaredError()] + if val_metrics is None: + val_metrics = [MeanSquaredError()] + + precision = torch.float16 if encode_latents_in_fp16 else None + # Make the text encoder + text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=precision) + tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') + + # Make the autoencoder + if autoencoder_path is None: + if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': + raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') + # Use the pretrained vae + downsample_factor = 8 + vae = AutoencoderKL.from_pretrained(model_name, subfolder='vae', torch_dtype=precision) + else: + # Use a custom autoencoder + vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) + if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): + raise ValueError( + 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') + if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_mean = tuple(latent_statistics['latent_channel_means']) + if isinstance(latent_std, str) and latent_std == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_std = tuple(latent_statistics['latent_channel_stds']) + downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) + + # Make the unet + unet_config = PretrainedConfig.get_config_dict(model_name, subfolder='unet')[0] + unet = UNet2DConditionModel.from_pretrained(model_name, subfolder='unet') + if pretrained: + controlnet = ControlNetModel.from_pretrained(controlnet) + else: + controlnet = ControlNetModel.from_unet(unet) + if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4: + raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.') + if isinstance(latent_mean, float): + latent_mean = (latent_mean,) * unet_config['in_channels'] + if isinstance(latent_std, float): + latent_std = (latent_std,) * unet_config['in_channels'] + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) + + # Make the noise schedulers + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + variance_type='fixed_small', + clip_sample=False, + prediction_type=prediction_type, + sample_max_value=1.0, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + + inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type) + + + if hasattr(controlnet, 'mid_block') and unet.controlnet is not None: + for attention in controlnet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in controlnet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in controlnet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in controlnet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + + for block in controlnet.controlnet_down_blocks: + block._fsdp_wrap = True + + controlnet.controlnet_mid_block._fsdp_wrap = True + + # Make the composer model + model = ControlNet( + unet=unet, + vae=vae, + controlnet = controlnet, + text_encoder=text_encoder, + tokenizer=tokenizer, + noise_scheduler=noise_scheduler, + inference_noise_scheduler=inference_noise_scheduler, + prediction_type=prediction_type, + latent_mean=latent_mean, + latent_std=latent_std, + downsample_factor=downsample_factor, + offset_noise=offset_noise, + train_metrics=train_metrics, + val_metrics=val_metrics, + quasirandomness=quasirandomness, + train_seed=train_seed, + val_seed=val_seed, + precomputed_latents=precomputed_latents, + encode_latents_in_fp16=encode_latents_in_fp16, + mask_pad_tokens=mask_pad_tokens, + fsdp=fsdp, + load_controlnet_from_composer=controlnet_from_composer_unet + ) + if torch.cuda.is_available(): + model = DeviceGPU().module_to_device(model) + if is_xformers_installed and use_xformers: + model.unet.enable_xformers_memory_efficient_attention() + if hasattr(model.vae, 'enable_xformers_memory_efficient_attention'): + model.vae.enable_xformers_memory_efficient_attention() + model.controlnet.enable_xformers_memory_efficient_attention() + + if clip_qkv is not None: + if is_xformers_installed and use_xformers: + attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) + else: + attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) + log.info('Using %s with clip_val %.1f' % (attn_processor.__class__, clip_qkv)) + model.unet.set_attn_processor(attn_processor) + model.controlnet.set_attn_processor(attn_processor) + + return model + def stable_diffusion_xl( tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', @@ -474,6 +681,273 @@ def stable_diffusion_xl( return model +def stable_diffusion_xl_controlnet( + tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', + 'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'), + text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder', + 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2'), + unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', + vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', + controlnet: str = None, + pretrained: bool = True, + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + prediction_type: str = 'epsilon', + latent_mean: Union[float, Tuple, str] = 0.0, + latent_std: Union[float, Tuple, str] = 7.67754318618, + beta_schedule: str = 'scaled_linear', + zero_terminal_snr: bool = False, + use_karras_sigmas: bool = False, + offset_noise: Optional[float] = None, + train_metrics: Optional[List] = None, + val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, + val_seed: int = 1138, + precomputed_latents: bool = False, + encode_latents_in_fp16: bool = True, + mask_pad_tokens: bool = False, + fsdp: bool = True, + clip_qkv: Optional[float] = None, + use_xformers: bool = True, + controlnet_from_composer_unet: bool = False, +): + """Stable diffusion 2 training setup + SDXL UNet and VAE. + + Requires batches of matched images and text prompts to train. Generates images from text + prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2. + + Args: + tokenizer_names (str, Tuple[str, ...]): HuggingFace name(s) of the tokenizer(s) to load. + Default: ``('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', + 'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2')``. + text_encoder_names (str, Tuple[str, ...]): HuggingFace name(s) of the text encoder(s) to load. + Default: ``('stabilityai/stable-diffusion-xl-base-1.0/text_encoder', + 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2')``. + unet_model_name (str): Name of the UNet model to load. Defaults to + 'stabilityai/stable-diffusion-xl-base-1.0'. + vae_model_name (str): Name of the VAE model to load. Defaults to + 'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from + 'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16. + pretrained (bool): Whether to load pretrained weights. Defaults to True. + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, + will use the vae from `model_name`. Default `None`. + autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. + prediction_type (str): The type of prediction to use. Must be one of 'sample', + 'epsilon', or 'v_prediction'. Default: `epsilon`. + latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, + a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `0.0`. + latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, + a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `1/0.13025`. + beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. + Default: `scaled_linear`. + zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. + use_karras_sigmas (bool): Whether to use the Karras sigmas for the diffusion process noise. Default: `False`. + offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not + be used. Default `None`. + train_metrics (list, optional): List of metrics to compute during training. If None, defaults to + [MeanSquaredError()]. + val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to + [MeanSquaredError()]. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. + train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. + val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. + precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. + encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. + mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False. + fsdp (bool): Whether to use FSDP. Defaults to True. + clip_qkv (float, optional): If not None, clip the qkv values to this value. Improves stability of training. + Default: ``None``. + use_xformers (bool): Whether to use xformers for attention. Defaults to True. + """ + latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) + + if (isinstance(tokenizer_names, tuple) or + isinstance(text_encoder_names, tuple)) and len(tokenizer_names) != len(text_encoder_names): + raise ValueError('Number of tokenizer_names and text_encoder_names must be equal') + + if train_metrics is None: + train_metrics = [MeanSquaredError()] + if val_metrics is None: + val_metrics = [MeanSquaredError()] + + # Make the tokenizer and text encoder + tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names) + text_encoder = MultiTextEncoder(model_names=text_encoder_names, + encode_latents_in_fp16=encode_latents_in_fp16, + pretrained_sdxl=pretrained) + + precision = torch.float16 if encode_latents_in_fp16 else None + # Make the autoencoder + if autoencoder_path is None: + if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': + raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') + downsample_factor = 8 + # Use the pretrained vae + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision) + except: # for handling SDXL vae fp16 fixed checkpoint + vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision) + else: + # Use a custom autoencoder + vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) + if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): + raise ValueError( + 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') + if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_mean = tuple(latent_statistics['latent_channel_means']) + if isinstance(latent_std, str) and latent_std == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_std = tuple(latent_statistics['latent_channel_stds']) + downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) + + # Make the unet + unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0] + unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet') + if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4: + raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.') + + if pretrained: + controlnet = ControlNetModel.from_pretrained(controlnet) + else: + controlnet = ControlNetModel.from_unet(unet) + if isinstance(latent_mean, float): + latent_mean = (latent_mean,) * unet_config['in_channels'] + if isinstance(latent_std, float): + latent_std = (latent_std,) * unet_config['in_channels'] + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) + + assert isinstance(unet, UNet2DConditionModel) + if hasattr(unet, 'mid_block') and unet.mid_block is not None: + for attention in unet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in unet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in unet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in unet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + + if hasattr(controlnet, 'mid_block') and unet.controlnet is not None: + for attention in controlnet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in controlnet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in controlnet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in controlnet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + + for block in controlnet.controlnet_down_blocks: + block._fsdp_wrap = True + + controlnet.controlnet_mid_block._fsdp_wrap = True + + # Make the noise schedulers + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + variance_type='fixed_small', + clip_sample=False, + prediction_type=prediction_type, + sample_max_value=1.0, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + if beta_schedule == 'squaredcos_cap_v2': + inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + rescale_betas_zero_snr=zero_terminal_snr) + else: + inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + prediction_type=prediction_type, + interpolation_type='linear', + use_karras_sigmas=use_karras_sigmas, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + + # Make the composer model + model = ControlNet( + unet=unet, + vae=vae, + controlnet = controlnet, + text_encoder=text_encoder, + tokenizer=tokenizer, + noise_scheduler=noise_scheduler, + inference_noise_scheduler=inference_noise_scheduler, + prediction_type=prediction_type, + latent_mean=latent_mean, + latent_std=latent_std, + downsample_factor=downsample_factor, + offset_noise=offset_noise, + train_metrics=train_metrics, + val_metrics=val_metrics, + quasirandomness=quasirandomness, + train_seed=train_seed, + val_seed=val_seed, + precomputed_latents=precomputed_latents, + encode_latents_in_fp16=encode_latents_in_fp16, + mask_pad_tokens=mask_pad_tokens, + fsdp=fsdp, + sdxl=True, + load_controlnet_from_composer=controlnet_from_composer_unet + ) + if torch.cuda.is_available(): + model = DeviceGPU().module_to_device(model) + if is_xformers_installed and use_xformers: + model.unet.enable_xformers_memory_efficient_attention() + if hasattr(model.vae, 'enable_xformers_memory_efficient_attention'): + model.vae.enable_xformers_memory_efficient_attention() + model.controlnet.enable_xformers_memory_efficient_attention() + + if clip_qkv is not None: + if is_xformers_installed and use_xformers: + attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) + else: + attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) + log.info('Using %s with clip_val %.1f' % (attn_processor.__class__, clip_qkv)) + model.unet.set_attn_processor(attn_processor) + model.controlnet.set_attn_processor(attn_processor) + + return model def build_autoencoder(input_channels: int = 3, output_channels: int = 3,