diff --git a/diffusion/callbacks/__init__.py b/diffusion/callbacks/__init__.py
index 48727c72..d9b35885 100644
--- a/diffusion/callbacks/__init__.py
+++ b/diffusion/callbacks/__init__.py
@@ -7,8 +7,10 @@
 from diffusion.callbacks.log_latent_statistics import LogLatentStatistics
 from diffusion.callbacks.nan_catcher import NaNCatcher
 from diffusion.callbacks.scheduled_garbage_collector import ScheduledGarbageCollector
+from diffusion.callbacks.assign_controlnet_weight import AssignControlNet
 
 __all__ = [
+    'AssignControlNet',
     'LogAutoencoderImages',
     'LogDiffusionImages',
     'LogLatentStatistics',
diff --git a/diffusion/callbacks/assign_controlnet_weight.py b/diffusion/callbacks/assign_controlnet_weight.py
new file mode 100644
index 00000000..76dc8a74
--- /dev/null
+++ b/diffusion/callbacks/assign_controlnet_weight.py
@@ -0,0 +1,49 @@
+from composer import Callback, Logger, State
+from composer.core import get_precision_context
+
+from diffusers import ControlNetModel, UNet2DConditionModel
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.nn.parallel import DistributedDataParallel
+
+class AssignControlNet(Callback):
+    """Copies matching UNet weights into the ControlNet after Composer loads a checkpoint.
+
+    Args:
+        use_fsdp: whether or not the model is FSDP wrapped
+    """
+
+    def __init__(self, use_fsdp: bool):
+        self.use_fsdp = use_fsdp
+
+    def process_controlnet(self, controlnet: ControlNetModel, unet: UNet2DConditionModel):
+        """Initialize the ControlNet's shared submodules from the UNet's weights."""
+        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+        controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+        if controlnet.class_embedding is not None:
+            controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+        if hasattr(controlnet, "add_embedding"):
+            controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
+
+        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+    def after_load(self, state: State, logger: Logger):
+        # Unwrap DDP to get at the underlying ComposerModel.
+        if isinstance(state.model, DistributedDataParallel):
+            model = state.model.module
+        else:
+            model = state.model
+
+        # Only seed the ControlNet when the model asked for it.
+        if model.load_controlnet_from_composer:
+            with get_precision_context(state.precision):
+                if self.use_fsdp:
+                    # Gather full (unsharded) params; write back only to the controlnet.
+                    with FSDP.summon_full_params(model.unet, recurse=True, writeback=False):
+                        with FSDP.summon_full_params(model.controlnet, recurse=True, writeback=True):
+                            self.process_controlnet(model.controlnet, model.unet)
+                else:
+                    self.process_controlnet(model.controlnet, model.unet)
diff --git a/diffusion/models/__init__.py b/diffusion/models/__init__.py
index 1a4bb0a2..74c04b4a 100644
--- a/diffusion/models/__init__.py
+++ b/diffusion/models/__init__.py
@@ -4,8 +4,8 @@
 """Diffusion models."""
 
 from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion,
-                                     discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl,
-                                     text_to_image_transformer)
+                                     discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl, stable_diffusion_2_controlnet,
+                                     stable_diffusion_xl_controlnet, text_to_image_transformer)
 from diffusion.models.noop import NoOpModel
 from diffusion.models.pixel_diffusion import PixelDiffusion
 from diffusion.models.stable_diffusion import StableDiffusion
@@ -19,6 +19,8 @@
     'PixelDiffusion',
     'stable_diffusion_2',
     'stable_diffusion_xl',
+    'stable_diffusion_2_controlnet',
+    'stable_diffusion_xl_controlnet',
     'StableDiffusion',
     'text_to_image_transformer',
 ]
diff --git a/diffusion/models/controlnet.py b/diffusion/models/controlnet.py
new file mode 100644
index 00000000..79409fbd
--- /dev/null
+++ b/diffusion/models/controlnet.py
@@ -0,0 +1,617 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""ControlNet Diffusion models."""
+
+from contextlib import nullcontext
+from typing import List, Optional, Tuple, Union
+import PIL
+import torch
+import torch.nn.functional as F
+from composer.models import ComposerModel
+from composer.utils import dist
+from scipy.stats import qmc
+from torchmetrics import MeanSquaredError
+from tqdm.auto import tqdm
+from diffusers.image_processor import VaeImageProcessor
+
+class ControlNet(ComposerModel):  # Trains a ControlNet against a frozen SD U-Net / VAE / text encoder.
+    def __init__(self,
+                 unet,
+                 vae,
+                 text_encoder,
+                 tokenizer,
+                 controlnet,
+                 noise_scheduler,
+                 inference_noise_scheduler,
+                 loss_fn=F.mse_loss,
+                 prediction_type: str = 'epsilon',
+                 latent_mean: Optional[Tuple[float]] = None,
+                 latent_std: Optional[Tuple[float]] = None,
+                 downsample_factor: int = 8,
+                 offset_noise: Optional[float] = None,
+                 train_metrics: Optional[List] = None,
+                 val_metrics: Optional[List] = None,
+                 quasirandomness: bool = False,
+                 train_seed: int = 42,
+                 val_seed: int = 1138,
+                 image_key: str = 'image',
+                 text_key: str = 'captions',
+                 image_latents_key: str = 'image_latents',
+                 text_latents_key: str = 'caption_latents',
+                 controlnet_images_key: str = 'controlnet_images',
+                 precomputed_latents: bool = False,
+                 encode_latents_in_fp16: bool = False,
+                 mask_pad_tokens: bool = False,
+                 fsdp: bool = False,
+                 sdxl: bool = False,
+                 load_controlnet_from_composer: bool = False):
+        super().__init__()
+        self.unet = unet
+        self.vae = vae
+        self.load_controlnet_from_composer = load_controlnet_from_composer
+        self.controlnet = controlnet
+        self.noise_scheduler = noise_scheduler
+        self.loss_fn = loss_fn
+        self.prediction_type = prediction_type.lower()
+        if self.prediction_type not in ['sample', 'epsilon', 'v_prediction']:
+            raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}')
+        self.downsample_factor = downsample_factor
+        self.offset_noise = offset_noise
+        self.quasirandomness = quasirandomness
+        self.train_seed = train_seed
+        self.val_seed = val_seed
+        self.image_key = image_key
+        self.controlnet_images_key = controlnet_images_key
+        self.image_latents_key = image_latents_key
+        self.precomputed_latents = precomputed_latents
+        self.mask_pad_tokens = mask_pad_tokens
+        self.sdxl = sdxl
+        if latent_mean is None:
+            latent_mean = 4 * (0.0,)  # fix: was `self.latent_mean = 4 * (0.0)`, leaving local `latent_mean` None below
+        if latent_std is None:
+            latent_std = 4 * (1 / 0.13025,) if self.sdxl else 4 * (1 / 0.18215,)  # fix: assign local, not attribute
+        self.latent_mean = torch.tensor(latent_mean).view(1, -1, 1, 1)
+        self.latent_std = torch.tensor(latent_std).view(1, -1, 1, 1)
+        self.train_metrics = train_metrics if train_metrics is not None else [MeanSquaredError()]
+        self.val_metrics = val_metrics if val_metrics is not None else [MeanSquaredError()]
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+        self.inference_scheduler = inference_noise_scheduler
+        self.text_key = text_key
+        self.text_latents_key = text_latents_key
+        self.encode_latents_in_fp16 = encode_latents_in_fp16
+        self.mask_pad_tokens = mask_pad_tokens
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.downsample_factor, do_convert_rgb=True, do_normalize=False
+        )
+        # freeze text_encoder during diffusion training
+        self.text_encoder.requires_grad_(False)
+        self.vae.requires_grad_(False)
+        self.unet.requires_grad_(False)
+        self.controlnet.train()
+        if self.encode_latents_in_fp16:
+            self.text_encoder = self.text_encoder.half()
+            self.vae = self.vae.half()
+        if fsdp:
+            # only wrap models in the diffusion process
+            self.text_encoder._fsdp_wrap = False
+            self.vae._fsdp_wrap = False
+            self.unet._fsdp_wrap = True
+            self.controlnet._fsdp_wrap = True
+
+        # Optional rng generator
+        self.rng_generator: Optional[torch.Generator] = None
+        if self.quasirandomness:
+            self.sobol = qmc.Sobol(d=1, scramble=True, seed=self.train_seed)
+
+    def _apply(self, fn):
+        super(ControlNet, self)._apply(fn)
+        self.latent_mean = fn(self.latent_mean)
+        self.latent_std = fn(self.latent_std)
+        return self
+
+    def _generate_timesteps(self, latents: torch.Tensor):
+        if self.quasirandomness:
+            # Generate a quasirandom sequence of timesteps equal to the global batch size
+            global_batch_size = latents.shape[0] * dist.get_world_size()
+            sampled_fractions = torch.tensor(self.sobol.random(global_batch_size), device=latents.device)
+            timesteps = (len(self.noise_scheduler) * sampled_fractions).squeeze()
+            timesteps = torch.floor(timesteps).long()
+            # Get this device's subset of all the timesteps
+            idx_offset = dist.get_global_rank() * latents.shape[0]
+            timesteps = timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device)
+        else:
+            timesteps = torch.randint(0,
+                                      len(self.noise_scheduler), (latents.shape[0],),
+                                      device=latents.device,
+                                      generator=self.rng_generator)
+        return timesteps
+
+    def set_rng_generator(self, rng_generator: torch.Generator):
+        """Sets the rng generator for the model."""
+        self.rng_generator = rng_generator
+
+    def forward(self, batch):
+        latents, text_embeds, text_pooled_embeds, attention_mask, encoder_attention_mask = None, None, None, None, None
+        if 'attention_mask' in batch and self.mask_pad_tokens:
+            attention_mask = batch['attention_mask']  # mask for text encoders
+            encoder_attention_mask = _create_unet_attention_mask(attention_mask)  # text mask for U-Net
+
+        # Use latents if specified and available. When specified, they might not exist during eval
+        if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch:
+            if self.sdxl:
+                raise NotImplementedError('SDXL not yet supported with precomputed latents')
+            latents, text_embeds = batch[self.image_latents_key], batch[self.text_latents_key]
+        else:
+            inputs, conditionings = batch[self.image_key], batch[self.text_key]
+
+            # If encode_latents_in_fp16, disable autocast context as models are in fp16
+            c = torch.cuda.amp.autocast(enabled=False) if self.encode_latents_in_fp16 else nullcontext()  # type: ignore
+            with c:
+                # Encode the images to the latent space.
+                if self.encode_latents_in_fp16:
+                    latents = self.vae.encode(inputs.half())['latent_dist'].sample().data
+                else:
+                    latents = self.vae.encode(inputs)['latent_dist'].sample().data
+                # Encode tokenized prompt into embedded text and pooled text embeddings
+                text_encoder_out = self.text_encoder(conditionings, attention_mask=attention_mask)
+                text_embeds = text_encoder_out[0]
+                if self.sdxl:
+                    if len(text_encoder_out) <= 1:
+                        raise RuntimeError('SDXL requires text encoder output to include a pooled text embedding')
+                    text_pooled_embeds = text_encoder_out[1]
+
+        # Scale the latents
+        latents = (latents - self.latent_mean) / self.latent_std
+
+        # Zero dropped captions if needed
+        if 'drop_caption_mask' in batch.keys():
+            text_embeds *= batch['drop_caption_mask'].view(-1, 1, 1)
+            if text_pooled_embeds is not None:
+                text_pooled_embeds *= batch['drop_caption_mask'].view(-1, 1)
+
+        # Sample the diffusion timesteps
+        timesteps = self._generate_timesteps(latents)
+        # Add noise to the inputs (forward diffusion)
+        noise = torch.randn(*latents.shape, device=latents.device, generator=self.rng_generator)
+        if self.offset_noise is not None:
+            offset_noise = torch.randn(latents.shape[0],
+                                       latents.shape[1],
+                                       1,
+                                       1,
+                                       device=noise.device,
+                                       generator=self.rng_generator)
+            noise += self.offset_noise * offset_noise
+        noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+
+        c = torch.cuda.amp.autocast(enabled=False) if self.encode_latents_in_fp16 else nullcontext()  # type: ignore
+        with c:
+            # Cast the controlnet conditioning image to fp16 to match the frozen weights if needed.
+            if self.encode_latents_in_fp16:
+                controlnet_image = batch[self.controlnet_images_key].half()
+            else:
+                controlnet_image = batch[self.controlnet_images_key]
+
+        added_cond_kwargs = {}
+        # if using SDXL, prepare added time ids & embeddings (needed by both the controlnet and the unet)
+        if self.sdxl:
+            add_time_ids = torch.cat(
+                [batch['cond_original_size'], batch['cond_crops_coords_top_left'], batch['cond_target_size']], dim=1)
+            added_cond_kwargs = {'text_embeds': text_pooled_embeds, 'time_ids': add_time_ids}
+
+        down_block_sample, mid_block_sample = self.controlnet(
+            noised_latents,
+            timesteps,
+            text_embeds,
+            controlnet_image,
+            attention_mask=encoder_attention_mask,
+            added_cond_kwargs=added_cond_kwargs,  # fix: SDXL controlnet needs the added conditioning too
+            return_dict=False)
+        # Generate the targets
+        if self.prediction_type == 'epsilon':
+            targets = noise
+        elif self.prediction_type == 'sample':
+            targets = latents
+        elif self.prediction_type == 'v_prediction':
+            targets = self.noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(
+                f'prediction type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}')
+
+        # Forward through the model
+        return self.unet(noised_latents,
+                         timesteps,
+                         text_embeds,
+                         encoder_attention_mask=encoder_attention_mask,
+                         down_block_additional_residuals=down_block_sample,
+                         mid_block_additional_residual=mid_block_sample,
+                         added_cond_kwargs=added_cond_kwargs)['sample'], targets, timesteps
+
+    def loss(self, outputs, batch):
+        """Loss between unet output and added noise, typically mse."""
+        return self.loss_fn(outputs[0], outputs[1])
+
+    def eval_forward(self, batch, outputs=None):
+        """For stable diffusion, eval forward computes unet outputs as well as some samples."""
+        # Skip this if outputs have already been computed, e.g. during training
+        if outputs is not None:
+            return outputs
+        return self.forward(batch)
+
+    def get_metrics(self, is_train: bool = False):
+        if is_train:
+            metrics = self.train_metrics
+        else:
+            metrics = self.val_metrics
+        metrics_dict = {metric.__class__.__name__: metric for metric in metrics}
+        return metrics_dict
+
+    def update_metric(self, batch, outputs, metric):
+        metric.update(outputs[0], outputs[1])
+
+    @torch.no_grad()
+    def generate(
+        self,
+        image: Optional[Union[PIL.Image.Image, torch.Tensor, List[PIL.Image.Image], List[torch.Tensor]]] = None,
+        prompt: Optional[list] = None,
+        negative_prompt: Optional[list] = None,
+        tokenized_prompts: Optional[torch.LongTensor] = None,
+        tokenized_negative_prompts: Optional[torch.LongTensor] = None,
+        tokenized_prompts_pad_mask: Optional[torch.LongTensor] = None,
+        tokenized_negative_prompts_pad_mask: Optional[torch.LongTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 3.0,
+        rescaled_guidance: Optional[float] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        seed: Optional[int] = None,
+        progress_bar: Optional[bool] = True,
+        zero_out_negative_prompt: bool = True,
+        crop_params: Optional[torch.Tensor] = None,
+        input_size_params: Optional[torch.Tensor] = None,
+        controlnet_conditioning_scale: Optional[float] = 1.0,
+        guess_mode: bool = False,
+        control_guidance_start: Optional[float] = 0.0,
+        control_guidance_end: Optional[float] = 1.0,
+    ):
+        """Generates image from noise.
+
+        Performs the backward diffusion process, each inference step takes
+        one forward pass through the unet.
+
+        Args:
+            prompt (str or List[str]): The prompt or prompts to guide the image generation.
+            image (PIL.Image.Image, torch.Tensor, List[PIL.Image.Image], List[torch.Tensor]): the images
+                for controlnet guidance.
+            negative_prompt (str or List[str]): The prompt or prompts to guide the
+                image generation away from. Ignored when not using guidance
+                (i.e., ignored if guidance_scale is less than 1).
+                Must be the same length as list of prompts. Default: `None`.
+            tokenized_prompts (torch.LongTensor): Optionally pass pre-tokenized prompts instead
+                of string prompts. If SDXL, this will be a tensor of size [B, 2, max_length],
+                otherwise will be of shape [B, max_length]. Default: `None`.
+            tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative
+                prompts instead of string prompts. Default: `None`.
+            tokenized_prompts_pad_mask (torch.LongTensor): Optionally pass padding mask for
+                pre-tokenized prompts. Default `None`.
+            tokenized_negative_prompts_pad_mask (torch.LongTensor): Optionally pass padding mask for
+                pre-tokenized negative prompts. Default `None`.
+            prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead
+                of string prompts. If both prompt and prompt_embeds
+                are passed, prompt_embeds will be used. Default: `None`.
+            negative_prompt_embeds (torch.FloatTensor): Optionally pass pre-embedded negative
+                prompts instead of string negative prompts. If both negative_prompt and
+                negative_prompt_embeds are passed, prompt_embeds will be used. Default: `None`.
+            height (int, optional): The height in pixels of the generated image.
+                Default: `self.unet.config.sample_size * 8`.
+            width (int, optional): The width in pixels of the generated image.
+                Default: `self.unet.config.sample_size * 8`.
+            num_inference_steps (int): The number of denoising steps.
+                More denoising steps usually lead to a higher quality image at the expense
+                of slower inference. Default: `50`.
+            guidance_scale (float): Guidance scale as defined in
+                Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation
+                2. of Imagen Paper. Guidance scale is enabled by setting guidance_scale > 1.
+                Higher guidance scale encourages to generate images that are closely linked
+                to the text prompt, usually at the expense of lower image quality.
+                Default: `3.0`.
+            rescaled_guidance (float, optional): Rescaled guidance scale. If not specified, rescaled guidance will
+                not be used. Default: `None`.
+            num_images_per_prompt (int): The number of images to generate per prompt.
+                Default: `1`.
+            progress_bar (bool): Whether to use the tqdm progress bar during generation.
+                Default: `True`.
+            seed (int): Random seed to use for generation. Set a seed for reproducible generation.
+                Default: `None`.
+            zero_out_negative_prompt (bool): Whether or not to zero out negative prompt if it is
+                an empty string. Default: `True`.
+            crop_params (torch.FloatTensor of size [Bx2], optional): Crop parameters to use
+                when generating images with SDXL. Default: `None`.
+            input_size_params (torch.FloatTensor of size [Bx2], optional): Size parameters
+                (representing original size of input image) to use when generating images with SDXL.
+                Default: `None`.
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                The ControlNet encoder tries to recognize the content of the input image even if you remove all
+                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
+        """
+        _check_prompt_given(prompt, tokenized_prompts, prompt_embeds)
+        _check_prompt_lenths(prompt, negative_prompt)
+        _check_prompt_lenths(tokenized_prompts, tokenized_negative_prompts)
+        _check_prompt_lenths(prompt_embeds, negative_prompt_embeds)
+
+        guess_mode = guess_mode or self.controlnet.config.global_pool_conditions
+        # Create rng for the generation
+        device = self.vae.device
+        rng_generator = torch.Generator(device=device)
+        if seed is not None:  # fix: `if seed:` silently ignored seed=0
+            rng_generator = rng_generator.manual_seed(seed)  # type: ignore
+
+        height = height or self.unet.config.sample_size * self.downsample_factor
+        width = width or self.unet.config.sample_size * self.downsample_factor
+        assert height is not None  # for type checking
+        assert width is not None  # for type checking
+
+        do_classifier_free_guidance = guidance_scale > 1.0  # type: ignore
+
+        text_embeddings, pooled_text_embeddings, pad_attn_mask = self._prepare_text_embeddings(
+            prompt, tokenized_prompts, tokenized_prompts_pad_mask, prompt_embeds, num_images_per_prompt)
+        batch_size = len(text_embeddings)  # len prompts * num_images_per_prompt
+        # classifier free guidance + negative prompts
+        # negative prompt is given in place of the unconditional input in classifier free guidance
+        pooled_embeddings, encoder_attn_mask = pooled_text_embeddings, pad_attn_mask
+        if do_classifier_free_guidance:
+            if not negative_prompt and not tokenized_negative_prompts and not negative_prompt_embeds and zero_out_negative_prompt:
+                # Negative prompt is empty and we want to zero it out
+                unconditional_embeddings = torch.zeros_like(text_embeddings)
+                pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) if self.sdxl else None
+                uncond_pad_attn_mask = torch.zeros_like(pad_attn_mask) if pad_attn_mask is not None else None
+            else:
+                if not negative_prompt:
+                    negative_prompt = [''] * (batch_size // num_images_per_prompt)  # type: ignore
+                unconditional_embeddings, pooled_unconditional_embeddings, uncond_pad_attn_mask = self._prepare_text_embeddings(
+                    negative_prompt, tokenized_negative_prompts, tokenized_negative_prompts_pad_mask,
+                    negative_prompt_embeds, num_images_per_prompt)
+
+            # concat uncond + prompt
+            text_embeddings = torch.cat([unconditional_embeddings, text_embeddings])
+            if self.sdxl:
+                pooled_embeddings = torch.cat([pooled_unconditional_embeddings, pooled_text_embeddings])  # type: ignore
+            if pad_attn_mask is not None:
+                encoder_attn_mask = torch.cat([uncond_pad_attn_mask, pad_attn_mask])  # type: ignore
+        else:
+            if self.sdxl:
+                pooled_embeddings = pooled_text_embeddings
+
+        image = self._prepare_image(image, width, height, batch_size, num_images_per_prompt, device, self.controlnet.dtype, do_classifier_free_guidance, guess_mode)
+        height, width = image.shape[-2:]
+
+        # prepare for diffusion generation process
+        latents = torch.randn(
+            (batch_size, self.unet.config.in_channels, height // self.downsample_factor,
+             width // self.downsample_factor),
+            device=device,
+            dtype=self.unet.dtype,
+            generator=rng_generator,
+        )
+
+        self.inference_scheduler.set_timesteps(num_inference_steps)
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.inference_scheduler.init_noise_sigma
+
+        added_cond_kwargs = {}
+        # if using SDXL, prepare added time ids & embeddings
+        if self.sdxl and pooled_embeddings is not None:
+            if crop_params is None:
+                crop_params = torch.zeros((batch_size, 2), dtype=text_embeddings.dtype)
+            if input_size_params is None:
+                input_size_params = torch.tensor([width, height], dtype=text_embeddings.dtype).repeat(batch_size, 1)
+            output_size_params = torch.tensor([width, height], dtype=text_embeddings.dtype).repeat(batch_size, 1)
+
+            if do_classifier_free_guidance:
+                crop_params = torch.cat([crop_params, crop_params])
+                input_size_params = torch.cat([input_size_params, input_size_params])
+                output_size_params = torch.cat([output_size_params, output_size_params])
+
+            add_time_ids = torch.cat([input_size_params, crop_params, output_size_params], dim=1).to(device)
+            added_cond_kwargs = {'text_embeds': pooled_embeddings, 'time_ids': add_time_ids}
+
+        use_controlnet = []
+        for i in range(num_inference_steps):
+            use_controlnet.append(1.0 - float(i / num_inference_steps < control_guidance_start or (i + 1) / num_inference_steps > control_guidance_end))
+        index = 0
+        # backward diffusion process
+        for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar):
+            if do_classifier_free_guidance:
+                latent_model_input = torch.cat([latents] * 2)
+            else:
+                latent_model_input = latents
+
+            latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t)
+
+            if guess_mode and do_classifier_free_guidance:  # fix: was `self.do_classifier_free_guidance` (no such attribute)
+                # Infer ControlNet only for the conditional batch.
+                control_model_input = latents
+                control_model_input = self.inference_scheduler.scale_model_input(control_model_input, t)  # fix: was `self.scheduler`
+                controlnet_prompt_embeds = text_embeddings.chunk(2)[1]  # fix: was `prompt_embeds`, which may be None
+                # Conditional half of any added cond kwargs (empty dict when not SDXL, so no missing names)
+                controlnet_added_cond_kwargs = {
+                    k: v.chunk(2)[1] for k, v in added_cond_kwargs.items()
+                }
+            else:
+                control_model_input = latent_model_input
+                controlnet_prompt_embeds = text_embeddings  # fix: was `prompt_embeds`
+                controlnet_added_cond_kwargs = added_cond_kwargs
+
+            cond_scale = controlnet_conditioning_scale * use_controlnet[index]
+
+            down_block_res_samples, mid_block_res_sample = self.controlnet(
+                control_model_input,
+                t,
+                encoder_hidden_states=controlnet_prompt_embeds,
+                controlnet_cond=image,
+                conditioning_scale=cond_scale,
+                guess_mode=guess_mode,
+                added_cond_kwargs=controlnet_added_cond_kwargs,
+                return_dict=False,
+            )
+
+            if guess_mode and do_classifier_free_guidance:
+                # Inferred ControlNet only for the conditional batch.
+                # To apply the output of ControlNet to both the unconditional and conditional batches,
+                # add 0 to the unconditional batch to keep it unchanged.
+                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+            # Model prediction
+            pred = self.unet(latent_model_input,
+                             t,
+                             encoder_hidden_states=text_embeddings,
+                             encoder_attention_mask=encoder_attn_mask,
+                             down_block_additional_residuals=down_block_res_samples,
+                             mid_block_additional_residual=mid_block_res_sample,
+                             added_cond_kwargs=added_cond_kwargs).sample
+
+            if do_classifier_free_guidance:
+                # perform guidance. Note this is only technically correct for prediction_type 'epsilon'
+                pred_uncond, pred_text = pred.chunk(2)
+                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+                # Optionally rescale the classifier free guidance
+                if rescaled_guidance is not None:
+                    std_pos = torch.std(pred_text, dim=(1, 2, 3), keepdim=True)
+                    std_cfg = torch.std(pred, dim=(1, 2, 3), keepdim=True)
+                    pred_rescaled = pred * (std_pos / std_cfg)
+                    pred = pred_rescaled * rescaled_guidance + pred * (1 - rescaled_guidance)
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.inference_scheduler.step(pred, t, latents, generator=rng_generator).prev_sample
+            index += 1
+
+        # We now use the vae to decode the generated latents back into the image.
+        # scale and decode the image latents with vae
+        latents = latents * self.latent_std + self.latent_mean
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        return image.detach()  # (batch*num_images_per_prompt, channel, h, w)
+
+    def _prepare_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
+
+    def _prepare_text_embeddings(self, prompt, tokenized_prompts, tokenized_pad_mask, prompt_embeds,
+                                 num_images_per_prompt):
+        """Tokenizes and embeds prompts if needed, then duplicates embeddings to support multiple generations per prompt."""
+        device = self.text_encoder.device
+        pooled_text_embeddings = None
+        if prompt_embeds is None:
+            if tokenized_prompts is None:
+                tokenized_out = self.tokenizer(prompt,
+                                               padding='max_length',
+                                               max_length=self.tokenizer.model_max_length,
+                                               truncation=True,
+                                               return_tensors='pt')
+                tokenized_prompts = tokenized_out['input_ids']
+                if self.mask_pad_tokens:
+                    tokenized_pad_mask = tokenized_out['attention_mask']
+                else:
+                    tokenized_pad_mask = None
+            if tokenized_pad_mask is not None:
+                tokenized_pad_mask = tokenized_pad_mask.to(device)
+            text_encoder_out = self.text_encoder(tokenized_prompts.to(device), attention_mask=tokenized_pad_mask)
+            prompt_embeds = text_encoder_out[0]
+            if self.sdxl:
+                if len(text_encoder_out) <= 1:
+                    raise RuntimeError('SDXL requires text encoder output to include a pooled text embedding')
+                pooled_text_embeddings = text_encoder_out[1]
+        else:
+            if self.sdxl:
+                raise NotImplementedError('SDXL not yet supported with precomputed embeddings')
+
+        # duplicate text embeddings for each generation per prompt
+        prompt_embeds = _duplicate_tensor(prompt_embeds, num_images_per_prompt)
+
+        if not self.mask_pad_tokens:
+            tokenized_pad_mask = None
+
+        if tokenized_pad_mask is not None:
+            tokenized_pad_mask = _create_unet_attention_mask(tokenized_pad_mask)
+            tokenized_pad_mask = _duplicate_tensor(tokenized_pad_mask, num_images_per_prompt)
+
+        if self.sdxl and pooled_text_embeddings is not None:
+            pooled_text_embeddings = _duplicate_tensor(pooled_text_embeddings, num_images_per_prompt)
+        return prompt_embeds, pooled_text_embeddings, tokenized_pad_mask
+
+def _check_prompt_lenths(prompt, negative_prompt):
+    if prompt is None and negative_prompt is None:
+        return
+    batch_size = 1 if isinstance(prompt, str) else len(prompt)
+    if negative_prompt:
+        negative_prompt_bs = 1 if isinstance(negative_prompt, str) else len(negative_prompt)
+        if negative_prompt_bs != batch_size:
+            raise ValueError('len(prompts) and len(negative_prompts) must be the same. \
+                A negative prompt must be provided for each given prompt.')
+
+
+def _check_prompt_given(prompt, tokenized_prompts, prompt_embeds):
+    if prompt is None and tokenized_prompts is None and prompt_embeds is None:
+        raise ValueError('Must provide one of `prompt`, `tokenized_prompts`, or `prompt_embeds`')
+
+
+def _create_unet_attention_mask(attention_mask):
+    """Takes the union of multiple attention masks if given more than one mask."""
+    if len(attention_mask.shape) == 2:
+        return attention_mask
+    elif len(attention_mask.shape) == 3:
+        encoder_attention_mask = attention_mask[:, 0]
+        for i in range(1, attention_mask.shape[1]):
+            encoder_attention_mask |= attention_mask[:, i]
+        return encoder_attention_mask
+    else:
+        raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_mask.shape}')
+
+
+def _duplicate_tensor(tensor, num_images_per_prompt):
+    """Duplicate tensor for multiple generations from a single prompt."""
+    batch_size, seq_len = tensor.shape[:2]
+    tensor = tensor.repeat(1, num_images_per_prompt, *[
+        1,
+    ] * len(tensor.shape[2:]))
+    return tensor.view(batch_size * num_images_per_prompt, seq_len, *[
+        -1,
+    ] * len(tensor.shape[2:]))
diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index 4326cf1e..3d79cf78 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -9,7 +9,7 @@
 import torch
 from composer.devices import DeviceGPU
-from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, ControlNetModel, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel
 from torchmetrics import MeanSquaredError
 from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig
 
@@ -20,6 +20,8 @@
 from diffusion.models.stable_diffusion import StableDiffusion
 from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
 from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer
MultiTextEncoder, MultiTokenizer +from diffusion.models.controlnet import ControlNet + from diffusion.models.transformer import DiffusionTransformer from diffusion.schedulers.schedulers import ContinuousTimeScheduler from diffusion.schedulers.utils import shift_noise_schedule @@ -232,6 +234,212 @@ def stable_diffusion_2( return model +def stable_diffusion_2_controlnet( + model_name: str = 'stabilityai/stable-diffusion-2-base', + pretrained: bool = True, + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + prediction_type: str = 'epsilon', + controlnet_from_composer_unet: bool = False, + controlnet: Optional[str] = None, + latent_mean: Union[float, Tuple, str] = 0.0, + latent_std: Union[float, Tuple, str] = 5.489980785067252, + beta_schedule: str = 'scaled_linear', + zero_terminal_snr: bool = False, + offset_noise: Optional[float] = None, + train_metrics: Optional[List] = None, + val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, + val_seed: int = 1138, + precomputed_latents: bool = False, + encode_latents_in_fp16: bool = True, + mask_pad_tokens: bool = False, + fsdp: bool = True, + clip_qkv: Optional[float] = None, + use_xformers: bool = True, +): + """Stable diffusion v2 training setup. + + Requires batches of matched images and text prompts to train. Generates images from text + prompts. + + Args: + model_name (str): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'. + pretrained (bool): Whether to load pretrained weights. Defaults to True. + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, + will use the vae from `model_name`. Default `None`. + autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. + prediction_type (str): The type of prediction to use. Must be one of 'sample', + 'epsilon', or 'v_prediction'. Default: `epsilon`. 
+ latent_mean (float, list, str): The mean of the autoencoder latents. Either a float for a single value, + a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `0.0`. + latent_std (float, list, str): The std. dev. of the autoencoder latents. Either a float for a single value, + a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `1/0.18215`. + beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. + Default: `scaled_linear`. + zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. + train_metrics (list, optional): List of metrics to compute during training. If None, defaults to + [MeanSquaredError()]. + val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to + [MeanSquaredError()]. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. + train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. + val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. + precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. + offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not + be used. Default `None`. + encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. + mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False. + fsdp (bool): Whether to use FSDP. Defaults to True. + clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to None. + use_xformers (bool): Whether to use xformers for attention. Defaults to True. 
+ """ + latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) + + if train_metrics is None: + train_metrics = [MeanSquaredError()] + if val_metrics is None: + val_metrics = [MeanSquaredError()] + + precision = torch.float16 if encode_latents_in_fp16 else None + # Make the text encoder + text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=precision) + tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') + + # Make the autoencoder + if autoencoder_path is None: + if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': + raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') + # Use the pretrained vae + downsample_factor = 8 + vae = AutoencoderKL.from_pretrained(model_name, subfolder='vae', torch_dtype=precision) + else: + # Use a custom autoencoder + vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) + if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): + raise ValueError( + 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') + if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_mean = tuple(latent_statistics['latent_channel_means']) + if isinstance(latent_std, str) and latent_std == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_std = tuple(latent_statistics['latent_channel_stds']) + downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) + + # Make the unet + unet_config = PretrainedConfig.get_config_dict(model_name, subfolder='unet')[0] + unet = UNet2DConditionModel.from_pretrained(model_name, subfolder='unet') + if pretrained: + controlnet = ControlNetModel.from_pretrained(controlnet) + else: + controlnet = 
ControlNetModel.from_unet(unet) + if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4: + raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.') + if isinstance(latent_mean, float): + latent_mean = (latent_mean,) * unet_config['in_channels'] + if isinstance(latent_std, float): + latent_std = (latent_std,) * unet_config['in_channels'] + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) + + # Make the noise schedulers + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + variance_type='fixed_small', + clip_sample=False, + prediction_type=prediction_type, + sample_max_value=1.0, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + + inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type) + + + if hasattr(controlnet, 'mid_block') and unet.controlnet is not None: + for attention in controlnet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in controlnet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in controlnet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in controlnet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + + for block in controlnet.controlnet_down_blocks: + block._fsdp_wrap = True + + controlnet.controlnet_mid_block._fsdp_wrap = True + + # Make the composer model + model = ControlNet( + unet=unet, + 
vae=vae, + controlnet = controlnet, + text_encoder=text_encoder, + tokenizer=tokenizer, + noise_scheduler=noise_scheduler, + inference_noise_scheduler=inference_noise_scheduler, + prediction_type=prediction_type, + latent_mean=latent_mean, + latent_std=latent_std, + downsample_factor=downsample_factor, + offset_noise=offset_noise, + train_metrics=train_metrics, + val_metrics=val_metrics, + quasirandomness=quasirandomness, + train_seed=train_seed, + val_seed=val_seed, + precomputed_latents=precomputed_latents, + encode_latents_in_fp16=encode_latents_in_fp16, + mask_pad_tokens=mask_pad_tokens, + fsdp=fsdp, + load_controlnet_from_composer=controlnet_from_composer_unet + ) + if torch.cuda.is_available(): + model = DeviceGPU().module_to_device(model) + if is_xformers_installed and use_xformers: + model.unet.enable_xformers_memory_efficient_attention() + if hasattr(model.vae, 'enable_xformers_memory_efficient_attention'): + model.vae.enable_xformers_memory_efficient_attention() + model.controlnet.enable_xformers_memory_efficient_attention() + + if clip_qkv is not None: + if is_xformers_installed and use_xformers: + attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) + else: + attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) + log.info('Using %s with clip_val %.1f' % (attn_processor.__class__, clip_qkv)) + model.unet.set_attn_processor(attn_processor) + model.controlnet.set_attn_processor(attn_processor) + + return model + def stable_diffusion_xl( tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', @@ -498,6 +706,273 @@ def stable_diffusion_xl( return model +def stable_diffusion_xl_controlnet( + tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', + 'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'), + text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder', + 
'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2'), + unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', + vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', + controlnet: str = None, + pretrained: bool = True, + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + prediction_type: str = 'epsilon', + latent_mean: Union[float, Tuple, str] = 0.0, + latent_std: Union[float, Tuple, str] = 7.67754318618, + beta_schedule: str = 'scaled_linear', + zero_terminal_snr: bool = False, + use_karras_sigmas: bool = False, + offset_noise: Optional[float] = None, + train_metrics: Optional[List] = None, + val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, + val_seed: int = 1138, + precomputed_latents: bool = False, + encode_latents_in_fp16: bool = True, + mask_pad_tokens: bool = False, + fsdp: bool = True, + clip_qkv: Optional[float] = None, + use_xformers: bool = True, + controlnet_from_composer_unet: bool = False, +): + """Stable diffusion 2 training setup + SDXL UNet and VAE. + + Requires batches of matched images and text prompts to train. Generates images from text + prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2. + + Args: + tokenizer_names (str, Tuple[str, ...]): HuggingFace name(s) of the tokenizer(s) to load. + Default: ``('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', + 'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2')``. + text_encoder_names (str, Tuple[str, ...]): HuggingFace name(s) of the text encoder(s) to load. + Default: ``('stabilityai/stable-diffusion-xl-base-1.0/text_encoder', + 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2')``. + unet_model_name (str): Name of the UNet model to load. Defaults to + 'stabilityai/stable-diffusion-xl-base-1.0'. + vae_model_name (str): Name of the VAE model to load. 
Defaults to + 'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from + 'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16. + pretrained (bool): Whether to load pretrained weights. Defaults to True. + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, + will use the vae from `model_name`. Default `None`. + autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. + prediction_type (str): The type of prediction to use. Must be one of 'sample', + 'epsilon', or 'v_prediction'. Default: `epsilon`. + latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, + a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `0.0`. + latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, + a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `1/0.13025`. + beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. + Default: `scaled_linear`. + zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. + use_karras_sigmas (bool): Whether to use the Karras sigmas for the diffusion process noise. Default: `False`. + offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not + be used. Default `None`. + train_metrics (list, optional): List of metrics to compute during training. If None, defaults to + [MeanSquaredError()]. + val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to + [MeanSquaredError()]. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. 
+ train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. + val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. + precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. + encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. + mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False. + fsdp (bool): Whether to use FSDP. Defaults to True. + clip_qkv (float, optional): If not None, clip the qkv values to this value. Improves stability of training. + Default: ``None``. + use_xformers (bool): Whether to use xformers for attention. Defaults to True. + """ + latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) + + if (isinstance(tokenizer_names, tuple) or + isinstance(text_encoder_names, tuple)) and len(tokenizer_names) != len(text_encoder_names): + raise ValueError('Number of tokenizer_names and text_encoder_names must be equal') + + if train_metrics is None: + train_metrics = [MeanSquaredError()] + if val_metrics is None: + val_metrics = [MeanSquaredError()] + + # Make the tokenizer and text encoder + tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names) + text_encoder = MultiTextEncoder(model_names=text_encoder_names, + encode_latents_in_fp16=encode_latents_in_fp16, + pretrained_sdxl=pretrained) + + precision = torch.float16 if encode_latents_in_fp16 else None + # Make the autoencoder + if autoencoder_path is None: + if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': + raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') + downsample_factor = 8 + # Use the pretrained vae + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision) + except: # for handling SDXL vae fp16 fixed checkpoint + vae = 
AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision) + else: + # Use a custom autoencoder + vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) + if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): + raise ValueError( + 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') + if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_mean = tuple(latent_statistics['latent_channel_means']) + if isinstance(latent_std, str) and latent_std == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_std = tuple(latent_statistics['latent_channel_stds']) + downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) + + # Make the unet + unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0] + unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet') + if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4: + raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.') + + if pretrained: + controlnet = ControlNetModel.from_pretrained(controlnet) + else: + controlnet = ControlNetModel.from_unet(unet) + if isinstance(latent_mean, float): + latent_mean = (latent_mean,) * unet_config['in_channels'] + if isinstance(latent_std, float): + latent_std = (latent_std,) * unet_config['in_channels'] + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) + + assert isinstance(unet, UNet2DConditionModel) + if hasattr(unet, 'mid_block') and unet.mid_block is not None: + for attention in unet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in unet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in unet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + 
attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in unet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + + if hasattr(controlnet, 'mid_block') and unet.controlnet is not None: + for attention in controlnet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in controlnet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in controlnet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in controlnet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + + for block in controlnet.controlnet_down_blocks: + block._fsdp_wrap = True + + controlnet.controlnet_mid_block._fsdp_wrap = True + + # Make the noise schedulers + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + variance_type='fixed_small', + clip_sample=False, + prediction_type=prediction_type, + sample_max_value=1.0, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + if beta_schedule == 'squaredcos_cap_v2': + inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + rescale_betas_zero_snr=zero_terminal_snr) + else: + inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, + beta_start=0.00085, + 
beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + prediction_type=prediction_type, + interpolation_type='linear', + use_karras_sigmas=use_karras_sigmas, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + + # Make the composer model + model = ControlNet( + unet=unet, + vae=vae, + controlnet = controlnet, + text_encoder=text_encoder, + tokenizer=tokenizer, + noise_scheduler=noise_scheduler, + inference_noise_scheduler=inference_noise_scheduler, + prediction_type=prediction_type, + latent_mean=latent_mean, + latent_std=latent_std, + downsample_factor=downsample_factor, + offset_noise=offset_noise, + train_metrics=train_metrics, + val_metrics=val_metrics, + quasirandomness=quasirandomness, + train_seed=train_seed, + val_seed=val_seed, + precomputed_latents=precomputed_latents, + encode_latents_in_fp16=encode_latents_in_fp16, + mask_pad_tokens=mask_pad_tokens, + fsdp=fsdp, + sdxl=True, + load_controlnet_from_composer=controlnet_from_composer_unet + ) + if torch.cuda.is_available(): + model = DeviceGPU().module_to_device(model) + if is_xformers_installed and use_xformers: + model.unet.enable_xformers_memory_efficient_attention() + if hasattr(model.vae, 'enable_xformers_memory_efficient_attention'): + model.vae.enable_xformers_memory_efficient_attention() + model.controlnet.enable_xformers_memory_efficient_attention() + + if clip_qkv is not None: + if is_xformers_installed and use_xformers: + attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) + else: + attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) + log.info('Using %s with clip_val %.1f' % (attn_processor.__class__, clip_qkv)) + model.unet.set_attn_processor(attn_processor) + model.controlnet.set_attn_processor(attn_processor) + + return model def text_to_image_transformer( tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer'),