diff --git a/.gitignore b/.gitignore index 6c4e8c9..86ec20e 100644 --- a/.gitignore +++ b/.gitignore @@ -56,5 +56,6 @@ wandb pretrained-* tuning-* models -*.sh -grid.png \ No newline at end of file +grid.png +aesthetics_65/* +output/* \ No newline at end of file diff --git a/README.md b/README.md index 59b9265..dadbd04 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,8 @@ If your target image is your face, you need to pre-train on a large face image d Or, if you have an artistic image, you might want to train on WikiArt like so. ``` accelerate launch pretrain_e4t.py \ - --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ - --clip_model_name_or_path="ViT-H-14::laion2b_s32b_b79k" \ + --mixed_precision="fp16" \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ --domain_class_token="art" \ --placeholder_token="*s" \ --prompt_template="art" \ @@ -44,13 +44,12 @@ accelerate launch pretrain_e4t.py \ --train_image_dataset="Artificio/WikiArt" \ --iterable_dataset \ --resolution=512 \ - --train_batch_size=16 \ + --train_batch_size=1 \ --learning_rate=1e-6 --scale_lr \ --checkpointing_steps=10000 \ --log_steps=1000 \ --max_train_steps=100000 \ --unfreeze_clip_vision \ - --mixed_precision="fp16" \ --enable_xformers_memory_efficient_attention ``` diff --git a/download_aesthetics.py b/download_aesthetics.py new file mode 100644 index 0000000..20d90bd --- /dev/null +++ b/download_aesthetics.py @@ -0,0 +1,53 @@ +from io import BytesIO +import os +from datasets import load_dataset +import requests +import concurrent.futures +import tqdm +from PIL import Image + +output_folder = os.path.join(os.path.dirname(__file__), "aesthetics_65") +dataset = load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus") + + +urls = [] +for item in tqdm.tqdm(dataset["train"]): + urls.append(item["URL"]) + +finished = 0 +max_items = len(urls) + +def download_image(url, output_dir): + global finished + try: + # Extract the image filename from the URL. 
+ filename = os.path.basename(url) + # Create the full output path for the image. + output_path = os.path.join(output_dir, filename) + # Download the image using the requests library. + response = requests.get(url) + response.raise_for_status() + # Save the image to the output path. + with open(output_path, 'wb') as f: + try: + img = Image.open(BytesIO(response.content)) + img.save(f, format=img.format) + except Exception: + print(f"Failed to open {url}") + return + finished += 1 + print("Finished {}/{}".format(finished, max_items)) + except Exception as e: + print(f"Failed to download {url}. Error: {e}") + +def parallel_download_images(urls, output_dir): + # Create the output directory if it doesn't exist. + os.makedirs(output_dir, exist_ok=True) + # Use a ThreadPoolExecutor to download the images in parallel. + with concurrent.futures.ThreadPoolExecutor() as executor: + # Submit the download tasks to the executor. + futures = [executor.submit(download_image, url, output_dir) for url in urls] + # Wait for all the tasks to complete. + concurrent.futures.wait(futures) + +parallel_download_images(urls, output_folder) \ No newline at end of file diff --git a/download_diffusiondb.py b/download_diffusiondb.py new file mode 100644 index 0000000..7f45bc0 --- /dev/null +++ b/download_diffusiondb.py @@ -0,0 +1,214 @@ +# Author: Marco Lustri 2022 - https://github.com/TheLustriVA +# MIT License + +"""A script to make downloading the DiffusionDB dataset easier.""" +from urllib.error import HTTPError +from urllib.request import urlretrieve +from alive_progress import alive_bar +from os.path import exists + +import shutil +import os +import time +import argparse + +index = None # initiate main arguments as None +range_max = None +output = None +unzip = None +large = None + +parser = argparse.ArgumentParser(description="Download a file from a URL") # + +# It's adding arguments to the parser. 
+parser.add_argument( + "-i", + "--index", + type=int, + default=1, + help="File to download or lower bound of range if -r is set", +) +parser.add_argument( + "-r", + "--range", + type=int, + default=2000, + help="Upper bound of range if -i is provided", +) +parser.add_argument( + "-o", "--output", type=str, default="images", help="Output directory name" +) +parser.add_argument( + "-z", + "--unzip", + default=False, + help="Unzip the file after downloading", + # It's setting the argument to True if it's provided. + action="store_true", +) +parser.add_argument( + "-l", + "--large", + default=False, + help="Download from DiffusionDB Large (14 million images)", + action="store_true", +) + +args = parser.parse_args() # parse the arguments + +# It's checking if the user has provided any arguments, and if they have, it +# sets the variables to the arguments. +if args.index: + index = args.index +if args.range: + range_max = args.range +if args.output: + output = args.output +if args.unzip: + unzip = args.unzip +if args.large: + large = args.large + +if ( + args.index and args.range and args.output and args.unzip and args.large is None +): # if no arguments are provided, set default behaviour + index = 1 + range_max = 2000 + output = "images" + unzip = False + large = False + + +def download(index=1, range_index=0, output="", large=False): + """ + Download a file from a URL and save it to a local file + + :param index: The index of the file to download, defaults to 1 (optional) + :param range_index: The number of files to download. 
If you want to download + all files, set this to the number of files you want to download, + defaults to 0 (optional) + :param output: The directory to download the files to :return: A list of + files to unzip + :param large: If downloading from DiffusionDB Large (14 million images) + instead of DiffusionDB 2M (2 million images) + """ + baseurl = "https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/" + files_to_unzip = [] + + if large: + if index <= 10000: + url = f"{baseurl}diffusiondb-large-part-1/part-{index:06}.zip" + else: + url = f"{baseurl}diffusiondb-large-part-2/part-{index:06}.zip" + else: + url = f"{baseurl}images/part-{index:06}.zip" + + if output != "": + output = f"{output}/" + + if not exists(output): + os.makedirs(output) + + if range_index == 0: + print("Downloading file: ", url) + file_path = f"{output}part-{index:06}.zip" + try: + urlretrieve(url, file_path) + except HTTPError as e: + print(f"Encountered an HTTPError downloading file: {url} - {e}") + if unzip: + unzip_file(file_path) + else: + # It's downloading the files numbered from index to range_index. + with alive_bar(range_index - index, title="Downloading files") as bar: + for idx in range(index, range_index): + if large: + if idx <= 10000: + url = f"{baseurl}diffusiondb-large-part-1/part-{idx:06}.zip" + else: + url = f"{baseurl}diffusiondb-large-part-2/part-{idx:06}.zip" + else: + url = f"{baseurl}images/part-{idx:06}.zip" + + loop_file_path = f"{output}part-{idx:06}.zip" + # It's trying to download the file, and if it encounters an + # HTTPError, it prints the error. + try: + urlretrieve(url, loop_file_path) + except HTTPError as e: + print(f"HTTPError downloading file: {url} - {e}") + files_to_unzip.append(loop_file_path) + # It's writing the url of the file to a manifest file. + with open("manifest.txt", "a") as f: + f.write(url + "\n") + time.sleep(0.1) + bar() + + # It's checking if the user wants to unzip the files, and if they do, it + # returns a list of files to unzip. 
It would be a bad idea to put these + # together as the process is already lengthy. + if unzip and len(files_to_unzip) > 0: + return files_to_unzip + + +def unzip_file(file: str, output: str = None): + """ + > This function takes a zip file as an argument and unpacks it + + :param file: str + :type file: str + :return: The file name without the .zip extension + """ + shutil.unpack_archive(file, extract_dir=output) + return f"File: {file.replace('.zip', '')} has been unzipped" + + +def unzip_all(files: list, output: str = None): + """ + > Unzip all files in a list of files + + :param files: list + :type files: list + """ + with alive_bar(len(files), title="Unzipping files") as bar: + for file in files: + unzip_file(file, output) + time.sleep(0.1) + bar() + + +def main(index=None, range_max=None, output=None, unzip=None, large=None): + """ + `main` is a function that takes in an index, a range_max, an output, and an + unzip, and if the user confirms that they have enough space, it downloads + the files from the index to the output, and if unzip is true, it unzips them + + :param index: The index of the file you want to download + :param range_max: The number of files to download + :param output: The directory to download the files to + :param unzip: If you want to unzip the files after downloading them, set + this to True + :param large: If you want to download from DiffusionDB Large (14 million + images) instead of DiffusionDB 2M (2 million images) + :return: A list of files that have been downloaded + """ + if range_max - index > 1999: + confirmation = input("Do you have at least 1.7Tb free: (y/n)") + if confirmation != "y": + return + if index and range_max: + files = download(index, range_max, output, large) + if unzip: + unzip_all(files, output) + elif index: + download(index, output=output, large=large) + else: + print("No index provided") + + +# This is a common pattern in Python. 
It allows you to run the main function of +# your script by running the script through the interpreter. It also allows you +# to import the script into the interpreter without automatically running the +# main function. +if __name__ == "__main__": + main(index, range_max, output, unzip, large) \ No newline at end of file diff --git a/download_diffusiondb.sh b/download_diffusiondb.sh new file mode 100644 index 0000000..542ce21 --- /dev/null +++ b/download_diffusiondb.sh @@ -0,0 +1,5 @@ +python download_diffusiondb.py \ + --index 1 \ + --range 50 \ + --unzip \ + --output "/home/ubuntu/e4t-diffusion/diffusondb" \ \ No newline at end of file diff --git a/e4t/utils.py b/e4t/utils.py index cc86363..69c907a 100644 --- a/e4t/utils.py +++ b/e4t/utils.py @@ -22,7 +22,7 @@ def __getstate__(self): return self.obj.items() def __setstate__(self, items): - if not hasattr(self, 'obj'): + if not hasattr(self, "obj"): self.obj = {} for key, val in items: self.obj[key] = val @@ -43,11 +43,7 @@ def keys(self): def download_from_huggingface(repo, filename, **kwargs): while True: try: - return huggingface_hub.hf_hub_download( - repo, - filename=filename, - **kwargs - ) + return huggingface_hub.hf_hub_download(repo, filename=filename, **kwargs) except HTTPError as e: if e.response.status_code == 401: # Need to log into huggingface api @@ -76,13 +72,17 @@ def download_from_huggingface(repo, filename, **kwargs): def load_config_from_pretrained(pretrained_model_name_or_path): if os.path.exists(pretrained_model_name_or_path): if "config.json" not in pretrained_model_name_or_path: - pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, "config.json") + pretrained_model_name_or_path = os.path.join( + pretrained_model_name_or_path, "config.json" + ) else: - assert pretrained_model_name_or_path in MODELS, f"Choose from {list(MODELS.keys())}" + assert ( + pretrained_model_name_or_path in MODELS + ), f"Choose from {list(MODELS.keys())}" pretrained_model_name_or_path = 
download_from_huggingface( repo=MODELS[pretrained_model_name_or_path]["repo"], filename="config.json", - subfolder=MODELS[pretrained_model_name_or_path]["subfolder"] + subfolder=MODELS[pretrained_model_name_or_path]["subfolder"], ) with open(pretrained_model_name_or_path, "r", encoding="utf-8") as f: pretrained_args = AttributeDict(json.load(f)) @@ -91,9 +91,12 @@ def load_config_from_pretrained(pretrained_model_name_or_path): def load_e4t_unet(pretrained_model_name_or_path=None, ckpt_path=None, **kwargs): assert pretrained_model_name_or_path is not None or ckpt_path is not None - if pretrained_model_name_or_path is None or not os.path.exists(ckpt_path): + if pretrained_model_name_or_path is None: if os.path.exists(ckpt_path): - assert os.path.basename(ckpt_path) == "unet.pt" or os.path.basename(ckpt_path) == "weight_offsets.pt", "You must specify the filename! (`unet.pt` or `weight_offsets.pt`)" + assert ( + os.path.basename(ckpt_path) == "unet.pt" + or os.path.basename(ckpt_path) == "weight_offsets.pt" + ), "You must specify the filename! 
(`unet.pt` or `weight_offsets.pt`)" config = load_config_from_pretrained(os.path.dirname(ckpt_path)) else: assert ckpt_path in MODELS, f"Choose from {list(MODELS.keys())}" @@ -102,16 +105,22 @@ def load_e4t_unet(pretrained_model_name_or_path=None, ckpt_path=None, **kwargs): ckpt_path = download_from_huggingface( repo=MODELS[ckpt_path]["repo"], filename="weight_offsets.pt", - subfolder=MODELS[ckpt_path]["subfolder"] + subfolder=MODELS[ckpt_path]["subfolder"], ) except EntryNotFoundError: ckpt_path = download_from_huggingface( repo=MODELS[ckpt_path]["repo"], filename="unet.pt", - subfolder=MODELS[ckpt_path]["subfolder"] + subfolder=MODELS[ckpt_path]["subfolder"], ) - pretrained_model_name_or_path = config.pretrained_model_name_or_path if config.pretrained_args is None else config.pretrained_args["pretrained_model_name_or_path"] - unet = OriginalUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", **kwargs) + pretrained_model_name_or_path = ( + config.pretrained_model_name_or_path + if config.pretrained_args is None + else config.pretrained_args["pretrained_model_name_or_path"] + ) + unet = OriginalUNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, subfolder="unet", **kwargs + ) state_dict = dict(unet.state_dict()) if ckpt_path: ckpt_sd = torch.load(ckpt_path, map_location="cpu") @@ -142,7 +151,7 @@ def load_e4t_encoder(ckpt_path=None, **kwargs): ckpt_path = download_from_huggingface( repo=MODELS[ckpt_path]["repo"], filename="encoder.pt", - subfolder=MODELS[ckpt_path]["subfolder"] + subfolder=MODELS[ckpt_path]["subfolder"], ) state_dict = torch.load(ckpt_path, map_location="cpu") print(f"Resuming from {ckpt_path}") @@ -182,7 +191,7 @@ def image_grid(imgs, rows, cols): assert len(imgs) == rows * cols w, h = imgs[0].size - grid = Image.new('RGB', size=(cols * w, rows * h)) + grid = Image.new("RGB", size=(cols * w, rows * h)) grid_w, grid_h = grid.size for i, img in enumerate(imgs): diff --git a/inference.py 
b/inference.py index 37081d6..78a8b6d 100644 --- a/inference.py +++ b/inference.py @@ -39,6 +39,7 @@ def parse_args(): parser.add_argument("--guidance_scale", type=float, default=1.0, help="unconditional guidance scale") parser.add_argument("--num_images_per_prompt", type=int, default=1, help="number of images per prompt") parser.add_argument("--height", type=int, default=512, help="image height, in pixel space",) + parser.add_argument("--output_dir", type=str, default="e4t-model", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--width", type=int, default=512, help="image width, in pixel space",) parser.add_argument("--seed", type=int, default=None, help="the seed (for reproducible sampling)") parser.add_argument("--scheduler_type", type=str, choices=["ddim", "plms", "lms", "euler", "euler_ancestral", "dpm_solver++"], default="ddim", help="diffusion scheduler type") @@ -135,6 +136,7 @@ def main(): generator = torch.Generator(device=device).manual_seed(args.seed) prompts = args.prompt.split("::") all_images = [] + output_dir = args.output_dir for prompt in tqdm(prompts): with torch.autocast(device), torch.inference_mode(): images = pipe( @@ -148,8 +150,10 @@ def main(): width=args.width, ).images all_images.extend(images) + for i, img in enumerate(images): + img.save(f"{output_dir}/{prompt}_{i}.png") grid_image = image_grid(all_images, len(prompts), args.num_images_per_prompt) - grid_image.save("grid.png") + grid_image.save(f"{output_dir}/grid.png") print("DONE! 
See `grid.png` for the results!") diff --git a/inference.sh b/inference.sh new file mode 100644 index 0000000..b8ecfdd --- /dev/null +++ b/inference.sh @@ -0,0 +1,17 @@ +INPUT_PATH=$1 +PROMPT=$2 +INPUT_PATH="/home/ubuntu/e4t-diffusion/training_images/$INPUT_PATH" +PROJECT="diffusiondb" + +echo "Prompt: $PROMPT" +echo "Input path: $INPUT_PATH" + +python inference.py \ + --pretrained_model_name_or_path "./output/$PROJECT/100/" \ + --prompt "$PROMPT" \ + --num_images_per_prompt 8 \ + --scheduler_type "ddim" \ + --output_dir "./output/$PROJECT" \ + --image_path_or_url="$INPUT_PATH" \ + --num_inference_steps 50 \ + --guidance_scale 7.5 \ No newline at end of file diff --git a/pretrain.sh b/pretrain.sh new file mode 100644 index 0000000..5b0555f --- /dev/null +++ b/pretrain.sh @@ -0,0 +1,19 @@ +accelerate launch pretrain_e4t.py \ + --mixed_precision="fp16" \ + --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ + --domain_class_token="art" \ + --placeholder_token="*s" \ + --prompt_template="art" \ + --save_sample_prompt="a photo in the style of *s,artwork in the style of *s" \ + --reg_lambda=0.02 \ + --domain_embed_scale=0.2 \ + --output_dir="pretrained-diffusiondb" \ + --train_image_dataset="/home/ubuntu/e4t-diffusion/diffusondb/" \ + --resolution=512 \ + --train_batch_size=1 \ + --learning_rate=2e-6 \ + --scale_lr \ + --checkpointing_steps=10000 \ + --log_steps=1000 \ + --max_train_steps=100000 \ + --enable_xformers_memory_efficient_attention \ No newline at end of file diff --git a/pretrain_config.yaml b/pretrain_config.yaml new file mode 100644 index 0000000..d0ce4d4 --- /dev/null +++ b/pretrain_config.yaml @@ -0,0 +1,15 @@ +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: fp16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline 
at end of file diff --git a/pretrain_e4t.py b/pretrain_e4t.py index fb4b62a..0998d07 100644 --- a/pretrain_e4t.py +++ b/pretrain_e4t.py @@ -10,7 +10,7 @@ import itertools import numpy as np -from PIL import Image +from PIL import Image, UnidentifiedImageError import albumentations from einops import rearrange import torch @@ -32,6 +32,7 @@ from e4t.pipeline_stable_diffusion_e4t import StableDiffusionE4TPipeline from e4t.utils import load_e4t_unet, load_e4t_encoder, save_e4t_unet, save_e4t_encoder, image_grid +wandb.login() templates = [ "a photo of {placeholder_token}", @@ -165,12 +166,16 @@ def __init__( def __len__(self): return len(self.dataset) - def __getitem__(self, idx): + def __getitem__(self,idx): image = self.dataset[idx] if self.from_datasets: image = image["image"] else: - image = Image.open(image) + try: + image = Image.open(image) + except UnidentifiedImageError as e: + print(f"UnidentifiedImageError: {image}") + return self.__getitem__((idx + 1) % len(self.dataset)) image = np.array(image.convert("RGB")) image = self.processor(image=image)["image"] image = (image / 127.5 - 1.0).astype(np.float32) diff --git a/requirements.txt b/requirements.txt index d5ec826..4b1a436 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ safetensors datasets bitsandbytes webdataset -open-clip-torch \ No newline at end of file +open-clip-torch +alive_progress \ No newline at end of file diff --git a/training_images/aesthetics/beach.jpeg b/training_images/aesthetics/beach.jpeg new file mode 100644 index 0000000..1ef1b1c Binary files /dev/null and b/training_images/aesthetics/beach.jpeg differ diff --git a/training_images/aesthetics/city.jpeg b/training_images/aesthetics/city.jpeg new file mode 100644 index 0000000..229310f Binary files /dev/null and b/training_images/aesthetics/city.jpeg differ diff --git a/training_images/art/fantasy.jpeg b/training_images/art/fantasy.jpeg new file mode 100644 index 0000000..03085d7 Binary files 
/dev/null and b/training_images/art/fantasy.jpeg differ diff --git a/training_images/art/matisse.jpg b/training_images/art/matisse.jpg new file mode 100644 index 0000000..c5b3ad4 Binary files /dev/null and b/training_images/art/matisse.jpg differ diff --git a/training_images/art/monet.jpg b/training_images/art/monet.jpg new file mode 100644 index 0000000..1f4916f Binary files /dev/null and b/training_images/art/monet.jpg differ diff --git a/training_images/art/monet2.jpg b/training_images/art/monet2.jpg new file mode 100644 index 0000000..2d59dbd Binary files /dev/null and b/training_images/art/monet2.jpg differ diff --git a/training_images/art/picasso.jpg b/training_images/art/picasso.jpg new file mode 100644 index 0000000..6d8839f Binary files /dev/null and b/training_images/art/picasso.jpg differ diff --git a/training_images/art/vangogh.jpg b/training_images/art/vangogh.jpg new file mode 100644 index 0000000..ac3417f Binary files /dev/null and b/training_images/art/vangogh.jpg differ diff --git a/training_images/isometric/isometric1.png b/training_images/isometric/isometric1.png new file mode 100644 index 0000000..88603d2 Binary files /dev/null and b/training_images/isometric/isometric1.png differ diff --git a/training_images/isometric/isometric2.png b/training_images/isometric/isometric2.png new file mode 100644 index 0000000..a47719d Binary files /dev/null and b/training_images/isometric/isometric2.png differ diff --git a/training_images/isometric/isometric3.png b/training_images/isometric/isometric3.png new file mode 100644 index 0000000..7350bbb Binary files /dev/null and b/training_images/isometric/isometric3.png differ diff --git a/training_images/isometric/isometric4.png b/training_images/isometric/isometric4.png new file mode 100644 index 0000000..7350bbb Binary files /dev/null and b/training_images/isometric/isometric4.png differ diff --git a/training_images/isometric/isometric5.png b/training_images/isometric/isometric5.png new file mode 100644 index 
0000000..4424a47 Binary files /dev/null and b/training_images/isometric/isometric5.png differ diff --git a/training_images/isometric/isometric6.png b/training_images/isometric/isometric6.png new file mode 100644 index 0000000..5add1c3 Binary files /dev/null and b/training_images/isometric/isometric6.png differ diff --git a/training_images/isometric/isometric7.png b/training_images/isometric/isometric7.png new file mode 100644 index 0000000..b14823c Binary files /dev/null and b/training_images/isometric/isometric7.png differ diff --git a/training_images/isometric/isometric8.png b/training_images/isometric/isometric8.png new file mode 100644 index 0000000..0dbeecb Binary files /dev/null and b/training_images/isometric/isometric8.png differ diff --git a/training_images/mj/1.png b/training_images/mj/1.png new file mode 100644 index 0000000..b612269 Binary files /dev/null and b/training_images/mj/1.png differ diff --git a/training_images/mj/2.png b/training_images/mj/2.png new file mode 100644 index 0000000..33d9847 Binary files /dev/null and b/training_images/mj/2.png differ diff --git a/training_images/mj/3.png b/training_images/mj/3.png new file mode 100644 index 0000000..554fbc7 Binary files /dev/null and b/training_images/mj/3.png differ diff --git a/training_images/mj/4.png b/training_images/mj/4.png new file mode 100644 index 0000000..e3b940a Binary files /dev/null and b/training_images/mj/4.png differ diff --git a/training_images/mj/5.png b/training_images/mj/5.png new file mode 100644 index 0000000..f292be1 Binary files /dev/null and b/training_images/mj/5.png differ diff --git a/tuning.sh b/tuning.sh new file mode 100644 index 0000000..ab3209c --- /dev/null +++ b/tuning.sh @@ -0,0 +1,16 @@ +INPUT_PATH=$1 +INPUT_PATH="/home/ubuntu/e4t-diffusion/training_images/$INPUT_PATH" +PROJECT="diffusiondb" + +accelerate launch tuning_e4t.py \ + --pretrained_model_name_or_path="/home/ubuntu/e4t-diffusion/pretrained-diffusiondb/100000/" \ + --reg_lambda=1e-4 \ + 
--output_dir "./output/$PROJECT" \ + --train_image_path="$INPUT_PATH" \ + --resolution=256 \ + --train_batch_size=4 \ + --learning_rate=1e-6 \ + --scale_lr \ + --max_train_steps=100 \ + --mixed_precision="fp16" \ + --enable_xformers_memory_efficient_attention \ No newline at end of file