Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions diffusion/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ def stable_diffusion_xl(
use_xformers: bool = True,
lora_rank: Optional[int] = None,
lora_alpha: Optional[int] = None,
cache_dir: str = '/tmp/hf_files',
local_files_only: bool = False,
):
"""Stable diffusion 2 training setup + SDXL UNet and VAE.

Expand Down Expand Up @@ -364,6 +366,9 @@ def stable_diffusion_xl(
use_xformers (bool): Whether to use xformers for attention. Defaults to True.
lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
cache_dir (str): Directory in which downloaded Hugging Face files (tokenizers, text encoders, UNet config and weights) are cached. Defaults to ``'/tmp/hf_files'``.
local_files_only (bool): Whether to load only files already present locally (e.g. in ``cache_dir``) instead of downloading them. Defaults to False.

"""
latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

Expand All @@ -377,10 +382,14 @@ def stable_diffusion_xl(
val_metrics = [MeanSquaredError()]

# Make the tokenizer and text encoder
tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names)
tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names,
cache_dir=cache_dir,
local_files_only=local_files_only)
text_encoder = MultiTextEncoder(model_names=text_encoder_names,
encode_latents_in_fp16=encode_latents_in_fp16,
pretrained_sdxl=pretrained)
pretrained_sdxl=pretrained,
cache_dir=cache_dir,
local_files_only=local_files_only)

precision = torch.float16 if encode_latents_in_fp16 else None
# Make the autoencoder
Expand Down Expand Up @@ -408,9 +417,15 @@ def stable_diffusion_xl(
downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)

# Make the unet
unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
unet_config = PretrainedConfig.get_config_dict(unet_model_name,
subfolder='unet',
cache_dir=cache_dir,
local_files_only=local_files_only)[0]
if pretrained:
unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
unet = UNet2DConditionModel.from_pretrained(unet_model_name,
subfolder='unet',
cache_dir=cache_dir,
local_files_only=local_files_only)
if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
else:
Expand Down Expand Up @@ -612,6 +627,7 @@ def precomputed_text_latent_diffusion(
use_xformers: bool = True,
lora_rank: Optional[int] = None,
lora_alpha: Optional[int] = None,
local_files_only: bool = False,
):
"""Latent diffusion model training using precomputed text latents from T5-XXL and CLIP.

Expand Down Expand Up @@ -662,6 +678,7 @@ def precomputed_text_latent_diffusion(
use_xformers (bool): Whether to use xformers for attention. Defaults to True.
lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
local_files_only (bool): Whether to load only files already present locally instead of downloading them. Defaults to False.
"""
latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

Expand Down Expand Up @@ -695,7 +712,10 @@ def precomputed_text_latent_diffusion(
downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)

# Make the unet
unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
unet_config = PretrainedConfig.get_config_dict(unet_model_name,
subfolder='unet',
cache_dir=cache_dir,
local_files_only=local_files_only)[0]

if isinstance(vae, AutoEncoder):
# Adapt the unet config to account for differing number of latent channels if necessary
Expand Down Expand Up @@ -792,20 +812,22 @@ def precomputed_text_latent_diffusion(
if include_text_encoders:
dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
dtype = dtype_map[text_encoder_dtype]
t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True)
t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl',
cache_dir=cache_dir,
local_files_only=local_files_only)
clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
subfolder='tokenizer',
cache_dir=cache_dir,
local_files_only=False)
local_files_only=local_files_only)
t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl',
torch_dtype=dtype,
cache_dir=cache_dir,
local_files_only=False).encoder.eval()
local_files_only=local_files_only).encoder.eval()
clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
subfolder='text_encoder',
torch_dtype=dtype,
cache_dir=cache_dir,
local_files_only=False).cuda().eval()
local_files_only=local_files_only).cuda().eval()
# Make the composer model
model = PrecomputedTextLatentDiffusion(
unet=unet,
Expand Down
49 changes: 35 additions & 14 deletions diffusion/models/text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ class MultiTextEncoder(torch.nn.Module):
the projected output from a CLIPTextModelWithProjection. Default: ``False``.
"""

def __init__(
self,
model_names: Union[str, Tuple[str, ...]],
model_dim_keys: Optional[Union[str, List[str]]] = None,
encode_latents_in_fp16: bool = True,
pretrained_sdxl: bool = False,
):
def __init__(self,
model_names: Union[str, Tuple[str, ...]],
model_dim_keys: Optional[Union[str, List[str]]] = None,
encode_latents_in_fp16: bool = True,
pretrained_sdxl: bool = False,
cache_dir: str = '/tmp/hf_files',
local_files_only: bool = False):
super().__init__()
self.pretrained_sdxl = pretrained_sdxl

Expand All @@ -50,7 +50,10 @@ def __init__(
name_split = model_name.split('/')
base_name = '/'.join(name_split[:2])
subfolder = '/'.join(name_split[2:])
text_encoder_config = PretrainedConfig.get_config_dict(base_name, subfolder=subfolder)[0]
text_encoder_config = PretrainedConfig.get_config_dict(base_name,
subfolder=subfolder,
cache_dir=cache_dir,
local_files_only=local_files_only)[0]

# Add text_encoder output dim to total dim
dim_found = False
Expand All @@ -70,14 +73,25 @@ def __init__(
architectures = text_encoder_config['architectures']
if architectures == ['CLIPTextModel']:
self.text_encoders.append(
CLIPTextModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
CLIPTextModel.from_pretrained(base_name,
subfolder=subfolder,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
local_files_only=local_files_only))
elif architectures == ['CLIPTextModelWithProjection']:
self.text_encoders.append(
CLIPTextModelWithProjection.from_pretrained(base_name, subfolder=subfolder,
torch_dtype=torch_dtype))
CLIPTextModelWithProjection.from_pretrained(base_name,
subfolder=subfolder,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
local_files_only=local_files_only))
else:
self.text_encoders.append(
AutoModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
AutoModel.from_pretrained(base_name,
subfolder=subfolder,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
local_files_only=local_files_only))
self.architectures += architectures

@property
Expand Down Expand Up @@ -125,7 +139,10 @@ class MultiTokenizer:
"org_name/repo_name/subfolder" where the subfolder is excluded if it is not used in the repo.
"""

def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
def __init__(self,
tokenizer_names_or_paths: Union[str, Tuple[str, ...]],
cache_dir: str = '/tmp/hf_files',
local_files_only: bool = False):
if isinstance(tokenizer_names_or_paths, str):
tokenizer_names_or_paths = (tokenizer_names_or_paths,)

Expand All @@ -134,7 +151,11 @@ def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
path_split = tokenizer_name_or_path.split('/')
base_name = '/'.join(path_split[:2])
subfolder = '/'.join(path_split[2:])
self.tokenizers.append(AutoTokenizer.from_pretrained(base_name, subfolder=subfolder))
self.tokenizers.append(
AutoTokenizer.from_pretrained(base_name,
subfolder=subfolder,
cache_dir=cache_dir,
local_files_only=local_files_only))

self.model_max_length = min([t.model_max_length for t in self.tokenizers])

Expand Down
Loading