diff --git a/Dockerfile b/Dockerfile index ce358274d71141688361c323c84f760395ea240d..e756b5bcb7b0028948fd5a24830e5c65dfb34aca 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,22 +1,28 @@ - FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime WORKDIR /app + +# System deps RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* -# Install ai-toolkit -RUN git clone https://github.com/ostris/ai-toolkit.git /app/ai-toolkit -WORKDIR /app/ai-toolkit -RUN git submodule update --init --recursive -RUN pip install --no-cache-dir -e . +# Pre-baked ai-toolkit (numpy 2.5 / dctorch / torchdiffeq / torchsde / clip forks, 3.14-compatible) +# Copied from local checkout so HF doesn't have to clone the submodule / run pip install -e . +COPY ai-toolkit /app/ai-toolkit + +# Make ai-toolkit importable. Upstream ai-toolkit ships without setup.py / pyproject.toml, +# so pip install -e . would fail. We add it to PYTHONPATH instead. +ENV PYTHONPATH=/app:/app/ai-toolkit +ENV PYTHONUNBUFFERED=1 +ENV HF_HOME=/app/hf_cache +ENV TRANSFORMERS_CACHE=/app/hf_cache -# Install HF Hub -RUN pip install --no-cache-dir huggingface_hub +# Install runtime deps +RUN pip install --no-cache-dir huggingface_hub hf_transfer # Copy training files COPY . 
/app/ -# Pre-download FLUX model and assistant LoRA -RUN python -c "from huggingface_hub import snapshot_download; snapshot_download('Niansuh/FLUX.1-schnell', cache_dir='/app/hf_cache'); snapshot_download('ostris/FLUX.1-schnell-training-adapter', cache_dir='/app/hf_cache')" +# Pre-download FLUX base + training adapter at build time so they're in the image cache +RUN python -c "import os; os.environ['HF_HUB_ENABLE_HF_TRANSFER']='1'; from huggingface_hub import snapshot_download; snapshot_download('Niansuh/FLUX.1-schnell'); snapshot_download('ostris/FLUX.1-schnell-training-adapter')" CMD ["python", "/app/train_cloud.py"] diff --git a/ai-toolkit/.gitignore b/ai-toolkit/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e6485fd40f8b2d35a69d8cbfb7433e7f962ebb2d --- /dev/null +++ b/ai-toolkit/.gitignore @@ -0,0 +1,187 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +.python +.node +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +/env.sh +/models +/datasets +/custom/* +!/custom/.gitkeep +/.tmp +/venv.bkp +/venv.* +/config/* +!/config/examples +!/config/_PUT_YOUR_CONFIGS_HERE).txt +/output/* +!/output/.gitkeep +/extensions/* +!/extensions/example +/temp +/wandb +.vscode/settings.json +.DS_Store +._.DS_Store +aitk_db.db +/notes.md +/data +.claude \ No newline at end of file diff --git a/ai-toolkit/.gitmodules b/ai-toolkit/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ai-toolkit/FAQ.md b/ai-toolkit/FAQ.md new file mode 100644 index 0000000000000000000000000000000000000000..a13ba46585573c3dc7c4cde94422633f251f04bf --- /dev/null +++ b/ai-toolkit/FAQ.md @@ -0,0 +1,10 @@ +# FAQ + +WIP. Will continue to add things as they are needed. + +## FLUX.1 Training + +#### How much VRAM is required to train a lora on FLUX.1? + +24GB minimum is required. 
+ diff --git a/ai-toolkit/LICENSE b/ai-toolkit/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d72f95d548901698a309fcf56d64c086ddc264cf --- /dev/null +++ b/ai-toolkit/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Ostris, LLC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/ai-toolkit/README.md b/ai-toolkit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c1db4b2557c226940bf89d46634d07eca8bf128e --- /dev/null +++ b/ai-toolkit/README.md @@ -0,0 +1,316 @@ +# Ostris AI Toolkit + +AI Toolkit is an easy to use all in one training suite for diffusion models. I try to support all the latest models on consumer grade hardware. Image and video models. It can be run as a GUI or CLI. It is designed to be easy to use but still have every feature imaginable. Free and open source. 
+ + + +## Supported Models + +### Image +- [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) (FLUX.1) +- [black-forest-labs/FLUX.2-dev](https://huggingface.co/black-forest-labs/FLUX.2-dev) (FLUX.2) +- [black-forest-labs/FLUX.2-klein-base-4B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B) (FLUX.2-klein-base-4B) +- [black-forest-labs/FLUX.2-klein-base-9B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B) (FLUX.2-klein-base-9B) +- [ostris/Flex.1-alpha](https://huggingface.co/ostris/Flex.1-alpha) (Flex.1) +- [ostris/Flex.2-preview](https://huggingface.co/ostris/Flex.2-preview) (Flex.2) +- [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) (Chroma) +- [Alpha-VLLM/Lumina-Image-2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) (Lumina2) +- [Qwen/Qwen-Image](https://huggingface.co/Qwen/Qwen-Image) (Qwen-Image) +- [Qwen/Qwen-Image-2512](https://huggingface.co/Qwen/Qwen-Image-2512) (Qwen-Image-2512) +- [HiDream-ai/HiDream-I1-Full](https://huggingface.co/HiDream-ai/HiDream-I1-Full) (HiDream I1) +- [OmniGen2/OmniGen2](https://huggingface.co/OmniGen2/OmniGen2) (OmniGen2) +- [Tongyi-MAI/Z-Image-Turbo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) (Z-Image Turbo) +- [Tongyi-MAI/Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image) (Z-Image) +- [ostris/Z-Image-De-Turbo](https://huggingface.co/ostris/Z-Image-De-Turbo) (Z-Image De-Turbo) +- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) (SDXL) +- [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) (SD 1.5) +- [baidu/ERNIE-Image](https://huggingface.co/baidu/ERNIE-Image) (ERNIE-Image) +- [NucleusAI/Nucleus-Image](https://huggingface.co/NucleusAI/Nucleus-Image) (Nucleus-Image) +- [HiDream-ai/HiDream-O1-Image](https://huggingface.co/HiDream-ai/HiDream-O1-Image) (HiDream O1) +- 
[Photoroom/prxpixel-t2i](https://huggingface.co/Photoroom/prxpixel-t2i) (PRXPixel) + +### Instruction / Edit +- [black-forest-labs/FLUX.1-Kontext-dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) (FLUX.1-Kontext-dev) +- [Qwen/Qwen-Image-Edit](https://huggingface.co/Qwen/Qwen-Image-Edit) (Qwen-Image-Edit) +- [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) (Qwen-Image-Edit-2509) +- [Qwen/Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) (Qwen-Image-Edit-2511) +- [HiDream-ai/HiDream-E1-1](https://huggingface.co/HiDream-ai/HiDream-E1-1) (HiDream E1) + +### Video +- [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) (Wan 2.1 1.3B) +- [Wan-AI/Wan2.1-I2V-14B-480P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers) (Wan 2.1 I2V 14B-480P) +- [Wan-AI/Wan2.1-I2V-14B-720P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers) (Wan 2.1 I2V 14B-720P) +- [Wan-AI/Wan2.1-T2V-14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) (Wan 2.1 14B) +- [Wan-AI/Wan2.2-T2V-A14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) (Wan 2.2 14B) +- [Wan-AI/Wan2.2-I2V-A14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) (Wan 2.2 I2V 14B) +- [Wan-AI/Wan2.2-TI2V-5B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers) (Wan 2.2 TI2V 5B) +- [Lightricks/LTX-2](https://huggingface.co/Lightricks/LTX-2) (LTX-2) +- [Lightricks/LTX-2.3](https://huggingface.co/Lightricks/LTX-2.3) (LTX-2.3) +- [krea/Krea-2-Raw](https://huggingface.co/krea/Krea-2-Raw) (Krea 2) + +### Audio +- [ACE-Step/Ace-Step1.5](https://huggingface.co/ACE-Step/Ace-Step1.5) (Ace Step 1.5) +- [ACE-Step/acestep-v15-xl-base](https://huggingface.co/ACE-Step/acestep-v15-xl-base) (Ace Step 1.5 XL) + +### Experimental +- [lodestones/Zeta-Chroma](https://huggingface.co/lodestones/Zeta-Chroma) (Zeta Chroma) +- 
[ideogram-ai/ideogram-4-fp8](https://huggingface.co/ideogram-ai/ideogram-4-fp8) (Ideogram 4 FP8) + +## Installation + +Requirements: +- python >=3.10 (3.12 recommended) +- Nvidia GPU with enough ram to do what you need +- python venv +- git + + +Linux: +```bash +git clone https://github.com/ostris/ai-toolkit.git +cd ai-toolkit +python3 -m venv venv +source venv/bin/activate +# install torch first +pip3 install --no-cache-dir torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128 +pip3 install -r requirements.txt +``` + +For devices running **DGX OS** (including DGX Spark), follow [these](dgx_instructions.md) instructions. + + +Windows: + +If you are having issues with Windows, I recommend using the easy install script at [https://github.com/Tavris1/AI-Toolkit-Easy-Install](https://github.com/Tavris1/AI-Toolkit-Easy-Install) + +```bash +git clone https://github.com/ostris/ai-toolkit.git +cd ai-toolkit +python -m venv venv +.\venv\Scripts\activate +pip install --no-cache-dir torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128 +pip install -r requirements.txt +``` + +MacOS: + +Experimental support for Silicon Macs is available. I do not have a Mac with enough RAM to fully test this +so please let me know if there are issues. There is a convenience script to install and run on MacOS +located at `./run_mac.zsh` that will install the dependencies locally and run the UI. To run this, +do the following: + +```bash +git clone https://github.com/ostris/ai-toolkit.git +cd ai-toolkit +chmod +x run_mac.zsh +./run_mac.zsh +``` + + +# AI Toolkit UI + +AI Toolkit UI + +The AI Toolkit UI is a web interface for the AI Toolkit. It allows you to easily start, stop, and monitor jobs. It also allows you to easily train models with a few clicks. It also allows you to set a token for the UI to prevent unauthorized access so it is mostly safe to run on an exposed server.
+ +## Running the UI + +Requirements: +- Node.js > 20 + +The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs. The commands below +will install / update the UI and its dependencies and start the UI. + +```bash +cd ui +npm run build_and_start +``` + +You can now access the UI at `http://localhost:8675` or `http://<your-ip>:8675` if you are running it on a server. + +## Securing the UI + +If you are hosting the UI on a cloud provider or any network that is not secure, I highly recommend securing it with an auth token. +You can do this by setting the environment variable `AI_TOOLKIT_AUTH` to a super secure password. This token will be required to access +the UI. You can set this when starting the UI like so: + +```bash +# Linux +AI_TOOLKIT_AUTH=super_secure_password npm run build_and_start + +# Windows +set AI_TOOLKIT_AUTH=super_secure_password && npm run build_and_start + +# Windows Powershell +$env:AI_TOOLKIT_AUTH="super_secure_password"; npm run build_and_start +``` + +### Training +1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml` +2. Edit the file following the comments in the file +3. Run the file like so `python run.py config/whatever_you_want.yml` + +A folder with the name and the training folder from the config file will be created when you start. It will have all +checkpoints and images in it. You can stop the training at any time using ctrl+c and when you resume, it will pick back up +from the last checkpoint. + +IMPORTANT. If you press ctrl+c while it is saving, it will likely corrupt that checkpoint. So wait until it is done saving. + +### Need help? + +Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU) +and ask for help there.
However, please refrain from PMing me directly with general questions or support requests. Ask in the Discord +and I will answer when I can. + +## Ostris Cloud + +You can use many cloud providers to rent GPUs. If you want to help support this project in the largest way possible, please consider using [Ostris Cloud](https://cloud.ostris.com). Ostris Cloud is owned and operated by me, Ostris, and every dollar earned goes directly back into funding the development of this project. + +Ostris Cloud + + +## Training in RunPod +If you would like to use Runpod, but have not signed up yet, please consider using [my Runpod affiliate link](https://runpod.io?ref=h0y9jyr2) to help support this project. + + +I maintain an official Runpod Pod template which can be accessed [here](https://console.runpod.io/deploy?template=0fqzfjy6f3&ref=h0y9jyr2). + +I have also created a short video showing how to get started using AI Toolkit with Runpod [here](https://youtu.be/HBNeS-F6Zz8). + +## Training in Modal + +### 1. Setup +#### ai-toolkit: +``` +git clone https://github.com/ostris/ai-toolkit.git +cd ai-toolkit +git submodule update --init --recursive +python -m venv venv +source venv/bin/activate +pip install torch +pip install -r requirements.txt +pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues +``` +#### Modal: +- Run `pip install modal` to install the modal Python package. +- Run `modal setup` to authenticate (if this doesn’t work, try `python -m modal setup`). + +#### Hugging Face: +- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev). +- Run `huggingface-cli login` and paste your token. + +### 2. Upload your dataset +- Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files in `ai-toolkit`. + +### 3.
Configs +- Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```. +- Edit the config following the comments in the file, **be careful and follow the example `/root/ai-toolkit` paths**. + +### 4. Edit run_modal.py +- Set the full path to your local `ai-toolkit` directory at `code_mount = modal.Mount.from_local_dir` like: + + ``` + code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit") + ``` +- Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and 2 hour timeout)_. + +### 5. Training +- Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`. +- You can monitor your training in your local terminal, or on [modal.com](https://modal.com/). +- Models, samples and optimizer will be stored in `Storage > flux-lora-models`. + +### 6. Saving the model +- Check contents of the volume by running `modal volume ls flux-lora-models`. +- Download the content by running `modal volume get flux-lora-models your-model-name`. +- Example: `modal volume get flux-lora-models my_first_flux_lora_v1`. + +### Screenshot from Modal + +Modal Training Screenshot + +--- + +## Dataset Preparation + +Datasets generally need to be a folder containing images and associated text files. Currently, the only supported +formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images +but with a `.txt` extension. For example `image2.jpg` and `image2.txt`. The text file should contain only the caption. +You can add the word `[trigger]` in the caption file and if you have `trigger_word` in your config, it will be automatically +replaced with the trigger word. + +Images are never upscaled but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**. +The loader will automatically resize them and can handle varying aspect ratios.
+ + +## Training Specific Layers + +To train specific layers with LoRA, you can use the `only_if_contains` network kwarg. For instance, if you want to train only the 2 layers +used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your +network kwargs like so: + +```yaml + network: + type: "lora" + linear: 128 + linear_alpha: 128 + network_kwargs: + only_if_contains: + - "transformer.single_transformer_blocks.7.proj_out" + - "transformer.single_transformer_blocks.20.proj_out" +``` + +The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal +the suffix of the name of the layers you want to train. You can also use this method to only train specific groups of weights. +For instance, to only train the `single_transformer` for FLUX.1, you can use the following: + +```yaml + network: + type: "lora" + linear: 128 + linear_alpha: 128 + network_kwargs: + only_if_contains: + - "transformer.single_transformer_blocks." +``` + +You can also exclude layers by their names by using the `ignore_if_contains` network kwarg. So to exclude all the single transformer blocks, you can use the following: + + +```yaml + network: + type: "lora" + linear: 128 + linear_alpha: 128 + network_kwargs: + ignore_if_contains: + - "transformer.single_transformer_blocks." +``` + +`ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both, +it will be ignored. + +## LoKr Training + +To learn more about LoKr, see [KohakuBlueleaf/LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). To train a LoKr model, you can adjust the network type in the config file like so: + +```yaml + network: + type: "lokr" + lokr_full_rank: true + lokr_factor: 8 +``` + +Everything else should work the same including layer targeting. + + +## Support My Work + +If you enjoy my projects or use them commercially, please consider sponsoring me. Every bit helps!
💖 + +Support my work + +### Current Sponsors + +All of these people / organizations are the ones who selflessly make this project possible. Thank you!! + +Sponsors diff --git a/ai-toolkit/build_and_push_docker b/ai-toolkit/build_and_push_docker new file mode 100644 index 0000000000000000000000000000000000000000..e891e22c8db0930c446b6d353fe87773fdfbaedc --- /dev/null +++ b/ai-toolkit/build_and_push_docker @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +# Extract version from version.py +if [ -f "version.py" ]; then + VERSION=$(python3 -c "from version import VERSION; print(VERSION)") + echo "Building version: $VERSION" +else + echo "Error: version.py not found. Please create a version.py file with VERSION defined." + exit 1 +fi + +echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo." +echo "Building version: $VERSION and latest" +# wait 2 seconds +sleep 2 + +# Build the image with cache busting +docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile . + +# Tag with version and latest +docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION +docker tag aitoolkit:$VERSION ostris/aitoolkit:latest + +# Push both tags +echo "Pushing images to Docker Hub..." +docker push ostris/aitoolkit:$VERSION +docker push ostris/aitoolkit:latest + +echo "Successfully built and pushed ostris/aitoolkit:$VERSION and ostris/aitoolkit:latest" \ No newline at end of file diff --git a/ai-toolkit/build_and_push_docker_dev b/ai-toolkit/build_and_push_docker_dev new file mode 100644 index 0000000000000000000000000000000000000000..9098d8cdf6f77e369ad922e863a37990c1548e38 --- /dev/null +++ b/ai-toolkit/build_and_push_docker_dev @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +VERSION=dev +GIT_COMMIT=dev + +echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo." 
+echo "Building version: $VERSION" +# wait 2 seconds +sleep 2 + +# Build the image with cache busting +docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile . + +# Tag with version +docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION + +# Push the tag +echo "Pushing images to Docker Hub..." +docker push ostris/aitoolkit:$VERSION + +echo "Successfully built and pushed ostris/aitoolkit:$VERSION" \ No newline at end of file diff --git a/ai-toolkit/config/examples/extract.example.yml b/ai-toolkit/config/examples/extract.example.yml new file mode 100644 index 0000000000000000000000000000000000000000..52505bb9058d81d4be3881bb20ecf6da214f571f --- /dev/null +++ b/ai-toolkit/config/examples/extract.example.yml @@ -0,0 +1,75 @@ +--- +# this is in yaml format. You can use json if you prefer +# I like both but yaml is easier to read and write +# plus it has comments which is nice for documentation +job: extract # tells the runner what to do +config: + # the name will be used to create a folder in the output folder + # it will also replace any [name] token in the rest of this config + name: name_of_your_model + # can be hugging face model, a .ckpt, or a .safetensors + base_model: "/path/to/base/model.safetensors" + # can be hugging face model, a .ckpt, or a .safetensors + extract_model: "/path/to/model/to/extract/trained.safetensors" + # we will create a folder here with the name above. This will create /path/to/output/folder/name_of_your_model + output_folder: "/path/to/output/folder" + is_v2: false + dtype: fp16 # saved dtype + device: cpu # cpu, cuda:0, etc + + # processes can be chained like this to run multiple in a row + # they must all use the same models above, but great for testing different + # sizes and types of extractions.
It is much faster as we already have the models loaded + process: + # process 1 + - type: locon # locon or lora (locon is lycoris) + filename: "[name]_64_32.safetensors" # will be put in output folder + dtype: fp16 + mode: fixed + linear: 64 + conv: 32 + + # process 2 + - type: locon + output_path: "/absolute/path/for/this/output.safetensors" # can be absolute + mode: ratio + linear: 0.2 + conv: 0.2 + + # process 3 + - type: locon + filename: "[name]_ratio_02.safetensors" + mode: quantile + linear: 0.5 + conv: 0.5 + + # process 4 + - type: lora # traditional lora extraction (lierla) with linear layers only + filename: "[name]_4.safetensors" + mode: fixed # fixed, ratio, quantile supported for lora as well + linear: 4 # lora dim or rank + # no conv for lora + + # process 5 + - type: lora + filename: "[name]_q05.safetensors" + mode: quantile + linear: 0.5 + +# you can put any information you want here, and it will be saved in the model +# the below is an example. I recommend doing trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. 
diff --git a/ai-toolkit/config/examples/generate.example.yaml b/ai-toolkit/config/examples/generate.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a3e19efdfee6f6215f37bb741a7b12e7c5ad484 --- /dev/null +++ b/ai-toolkit/config/examples/generate.example.yaml @@ -0,0 +1,60 @@ +--- + +job: generate # tells the runner what to do +config: + name: "generate" # this is not really used anywhere currently but required by runner + process: + # process 1 + - type: to_folder # process images to a folder + output_folder: "output/gen" + device: cuda:0 # cpu, cuda:0, etc + generate: + # these are your defaults; you can override most of them with flags + sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now + width: 1024 + height: 1024 + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: -1 # -1 is random + guidance_scale: 7 + sample_steps: 20 + ext: ".png" # .png, .jpg, .jpeg, .webp + + # here are the flags you can use for prompts. Always start with + # your prompt first then add these flags after. You can use as many + # as you like, for example: + # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20 + # we will try to support all sd-scripts flags where we can + + # FROM SD-SCRIPTS + # --n Treat everything until the next option as a negative prompt. + # --w Specify the width of the generated image. + # --h Specify the height of the generated image. + # --d Specify the seed for the generated image. + # --l Specify the CFG scale for the generated image. + # --s Specify the number of steps during generation.
+ + # OURS and some QOL additions + # --p2 Prompt for the second text encoder (SDXL only) + # --n2 Negative prompt for the second text encoder (SDXL only) + # --gr Specify the guidance rescale for the generated image (SDXL only) + # --seed Specify the seed for the generated image, same as --d + # --cfg Specify the CFG scale for the generated image, same as --l + # --steps Specify the number of steps during generation, same as --s + + prompt_file: false # if true a txt file will be created next to images with prompt strings used + # prompts can also be a path to a text file with one prompt per line + # prompts: "/path/to/prompts.txt" + prompts: + - "photo of batman" + - "photo of superman" + - "photo of spiderman" + - "photo of a superhero --n batman superman spiderman" + + model: + # huggingface name, path relative to the project, or absolute path to .safetensors or .ckpt + # name_or_path: "runwayml/stable-diffusion-v1-5" + name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors" + is_v2: false # for v2 models + is_v_pred: false # for v-prediction models (most v2 models) + is_xl: false # for SDXL models + dtype: bf16 diff --git a/ai-toolkit/config/examples/mod_lora_scale.yaml b/ai-toolkit/config/examples/mod_lora_scale.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f59ecc838b1e4a4465600a67ddb690331ba3255 --- /dev/null +++ b/ai-toolkit/config/examples/mod_lora_scale.yaml @@ -0,0 +1,48 @@ +--- +job: mod +config: + name: name_of_your_model_v1 + process: + - type: rescale_lora + # path to your current lora model + input_path: "/path/to/lora/lora.safetensors" + # output path for your new lora model, can be the same as input_path to replace + output_path: "/path/to/lora/output_lora_v1.safetensors" + # replaces meta with the meta below (plus minimum meta fields) + # if false, we will leave the meta alone except for updating hashes (sd-script hashes) + replace_meta: true + # how to adjust, we can scale
the up_down weights or the alpha + # up_down is the default and probably the best, they will both net the same outputs + # would only affect rare NaN cases and maybe merging with old merge tools + scale_target: 'up_down' + # precision to save, fp16 is the default and standard + save_dtype: fp16 + # current_weight is the ideal weight you use as a multiplier when using the lora + # IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight + # you can do negatives here too if you want to flip the lora + current_weight: 6.0 + # target_weight is the ideal weight you use as a multiplier when using the lora + # instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0> + # we want to use <lora:my_lora:1.0> so 1.0 is the target_weight + target_weight: 1.0 + + # base model for the lora + # this is just used to add meta so automatic1111 knows which model it is for + # assume v1.5 if these are not set + is_xl: false + is_v2: false +meta: + # this is only used if you set replace_meta to true above + name: "[name]" # [name] gets replaced with the name above + description: A short description of your lora + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want.
diff --git a/ai-toolkit/config/examples/modal/modal_train_lora_flux_24gb.yaml b/ai-toolkit/config/examples/modal/modal_train_lora_flux_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51873de0b0b5e0e86153cc19c7561ba18272ec34 --- /dev/null +++ b/ai-toolkit/config/examples/modal/modal_train_lora_flux_24gb.yaml @@ -0,0 +1,96 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir: + - folder_path: "/root/ai-toolkit/your-dataset" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new vell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + # if you get an error, or get stuck while downloading, + # check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and + # place it like "/root/ai-toolkit/FLUX.1-dev" + name_or_path: "black-forest-labs/FLUX.1-dev" + is_flux: true + quantize: true # run 8bit mixed precision +# low_vram: true # uncomment this if the GPU is connected to your monitors. 
It will use less vram to quantize, but is slower. + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 20 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml b/ai-toolkit/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d1e964fe9479e333566fd89ee8a4b2724231346 --- /dev/null +++ b/ai-toolkit/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml @@ -0,0 +1,98 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir: + - folder_path: "/root/ai-toolkit/your-dataset" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new vell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. 
+ ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + # if you get an error, or get stuck while downloading, + # check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and + # place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter" + name_or_path: "black-forest-labs/FLUX.1-schnell" + assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training + is_flux: true + quantize: true # run 8bit mixed precision + # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary +# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower. + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, 
in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 1 # schnell does not do guidance + sample_steps: 4 # 1 - 4 works well +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_flex_redux.yaml b/ai-toolkit/config/examples/train_flex_redux.yaml new file mode 100644 index 0000000000000000000000000000000000000000..918de842792903f09ec08a3e55dfad59958cccf1 --- /dev/null +++ b/ai-toolkit/config/examples/train_flex_redux.yaml @@ -0,0 +1,112 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flex_redux_finetune_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + adapter: + type: "redux" + # you can finetune an existing adapter or start from scratch. Set to null to start from scratch + name_or_path: '/local/path/to/redux_adapter_to_finetune.safetensors' + # name_or_path: null + # image_encoder_path: 'google/siglip-so400m-patch14-384' # Flux.1 redux adapter + image_encoder_path: 'google/siglip2-so400m-patch16-512' # Flex.1 512 redux adapter + # image_encoder_arch: 'siglip' # for Flux.1 + image_encoder_arch: 'siglip2' + # You need a control input for each sample. 
Best to do squares for both images + test_img_path: + - "/path/to/x_01.jpg" + - "/path/to/x_02.jpg" + - "/path/to/x_03.jpg" + - "/path/to/x_04.jpg" + - "/path/to/x_05.jpg" + - "/path/to/x_06.jpg" + - "/path/to/x_07.jpg" + - "/path/to/x_08.jpg" + - "/path/to/x_09.jpg" + - "/path/to/x_10.jpg" + clip_layer: 'last_hidden_state' + train: true + save: + dtype: bf16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + # clip_image_path is the directory containing your control images. They must have the same filename as their train image. (extension does not matter) + # for normal redux, we are just recreating the same image, so you can use the same folder path above + clip_image_path: "/path/to/control/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions + train: + # this is what I used for the 24GB card, but feel free to adjust + # total batch size is 6 here + batch_size: 3 + gradient_accumulation: 2 + + # captions are not needed for this training, we cache a blank prompt and rely on the vision encoder + unload_text_encoder: true + + loss_type: "mse" + train_unet: true + train_text_encoder: false + steps: 4000000 # I set this very high and stop when I like the results + content_or_style: balanced # content, style, balanced + gradient_checkpointing: true + noise_scheduler: "flowmatch" # or "ddpm", "lms", "euler_a" + timestep_type: "flux_shift" + optimizer: "adamw8bit" + lr: 1e-4 + + # this is for Flex.1,
comment this out for FLUX.1-dev + bypass_guidance_embedding: true + + dtype: bf16 + ema_config: + use_ema: true + ema_decay: 0.99 + model: + name_or_path: "ostris/Flex.1-alpha" + is_flux: true + quantize: true + text_encoder_bits: 8 + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + # I leave half blank to test prompt and unprompted + prompts: + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "" + - "" + - "" + - "" + - "" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 + network_multiplier: 1.0 + +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_full_fine_tune_flex.yaml b/ai-toolkit/config/examples/train_full_fine_tune_flex.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f386a3b1e18aa126093f0542817a5b968ed8bb09 --- /dev/null +++ b/ai-toolkit/config/examples/train_full_fine_tune_flex.yaml @@ -0,0 +1,107 @@ +--- +# This configuration requires 48GB of VRAM or more to operate +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flex_finetune_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps + # performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word + # trigger_word: "p3r5on" + save: + dtype: bf16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 2 # how many intermittent saves to keep + save_format: 'diffusers' # 'diffusers' + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + # cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions + train: + batch_size: 1 + # IMPORTANT! For Flex, you must bypass the guidance embedder during training + bypass_guidance_embedding: true + + # can be 'sigmoid', 'linear', or 'lognorm_blend' + timestep_type: 'sigmoid' + + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with flex + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adafactor" + lr: 3e-5 + + # Parameter swapping can reduce vram requirements. Set factor from 1.0 to 0.0. + # 0.1 is 10% of parameters active at each step. Only works with adafactor + + # do_paramiter_swapping: true + # paramiter_swapping_factor: 0.9 + + # uncomment this to skip the pre training sample + # skip_first_sample: true + # uncomment to completely disable sampling + # disable_sampling: true + + # ema will smooth out learning, but could slow it down.
Recommended to leave on if you have the vram + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flex, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "ostris/Flex.1-alpha" + is_flux: true # flex is flux architecture + # full finetuning quantized models is a crapshoot and results in subpar outputs + # quantize: true + # you can quantize just the T5 text encoder here to save vram + quantize_te: true + # only train the transformer blocks + only_if_contains: + - "transformer.transformer_blocks." + - "transformer.single_transformer_blocks." + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word + # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flex + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_full_fine_tune_lumina.yaml b/ai-toolkit/config/examples/train_full_fine_tune_lumina.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51a61737237afc016c135d2994158a9e547edbda --- /dev/null +++ b/ai-toolkit/config/examples/train_full_fine_tune_lumina.yaml @@ -0,0 +1,99 @@ +--- +# This configuration requires 24GB of VRAM or more to operate +job: extension +config: + # this name will be the folder and filename name + name: "my_first_lumina_finetune_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps + # performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word + # trigger_word: "p3r5on" + save: + dtype: bf16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 2 # how many intermittent saves to keep + save_format: 'diffusers' # 'diffusers' + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + # cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions + train: + batch_size: 1 + + # can be 'sigmoid', 'linear', or 'lumina2_shift' + timestep_type: 'lumina2_shift' + + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with lumina2 + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adafactor" + lr: 3e-5 + + # Parameter swapping can reduce vram requirements. Set factor from 1.0 to 0.0. + # 0.1 is 10% of parameters active at each step. Only works with adafactor + + # do_paramiter_swapping: true + # paramiter_swapping_factor: 0.9 + + # uncomment this to skip the pre training sample + # skip_first_sample: true + # uncomment to completely disable sampling + # disable_sampling: true + + # ema will smooth out learning, but could slow it down.
Recommended to leave on if you have the vram + # ema_config: + # use_ema: true + # ema_decay: 0.99 + + # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "Alpha-VLLM/Lumina-Image-2.0" + is_lumina2: true # lumina2 architecture + # you can quantize just the Gemma2 text encoder here to save vram + quantize_te: true + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word + # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear." 
+ - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4.0 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_chroma_24gb.yaml b/ai-toolkit/config/examples/train_lora_chroma_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..807f2bbab6d2e13e6e7a34a35262164d203eb066 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_chroma_24gb.yaml @@ -0,0 +1,104 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_chroma_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # chroma enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with chroma + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for chroma, other dtypes may not work correctly + dtype: bf16 + model: + # Download whichever model you prefer from the Chroma repo + # https://huggingface.co/lodestones/Chroma/tree/main + # point to it here.
+ # name_or_path: "/path/to/chroma/chroma-unlocked-vVERSION.safetensors" + + # using lodestones/Chroma will automatically use the latest version + name_or_path: "lodestones/Chroma" + + # # You can also select a version of Chroma like so + # name_or_path: "lodestones/Chroma/v28" + + arch: "chroma" + quantize: true # run 8bit mixed precision + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # negative prompt, optional + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_flex2_24gb.yaml b/ai-toolkit/config/examples/train_lora_flex2_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cc698014e9c7a0aec0e951dc1a15d6b59e14cb40 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_flex2_24gb.yaml @@ -0,0 +1,165 @@ +# Note, Flex2 is a highly experimental WIP model. Finetuning a model with built in controls and inpainting has not +# been done before, so you will be experimenting with me on how to do it. This is my recommended setup, but this is highly +# subject to change as we learn more about how Flex2 works. + +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flex2_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 32 + linear_alpha: 32 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + # Flex2 is trained with controls and inpainting. If you want the model to truely understand how the + # controls function with your dataset, it is a good idea to keep doing controls during training. + # this will automatically generate the controls for you before training. The current script is not + # fully optimized so this could be rather slow for large datasets, but it caches them to disk so it + # only needs to be done once. If you want to skip this step, you can set the controls to [] and it will + controls: + - "depth" + - "line" + - "pose" + - "inpaint" + + # you can make custom inpainting images as well. These images must be webp or png format with an alpha. + # just erase the part of the image you want to inpaint and save it as a webp or png. Again, erase your + # train target. So the person if training a person. The automatic controls above with inpaint will + # just run a background remover mask and erase the foreground, which works well for subjects. + + # inpaint_path: "/my/impaint/images" + + # you can also specify existing control image pairs. It can handle multiple groups and will randomly + # select one for each step. + + # control_path: + # - "/my/custom/control/images" + # - "/my/custom/control/images2" + + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + resolution: [ 512, 768, 1024 ] # flex2 enjoys multiple resolutions + train: + batch_size: 1 + # IMPORTANT! 
For Flex2, you must bypass the guidance embedder during training + bypass_guidance_embedding: true + + steps: 3000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with flex2 + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + # shift works well for training fast and learning composition and style. + # for just subject, you may want to change this to sigmoid + timestep_type: 'shift' # 'linear', 'sigmoid', 'shift' + optimizer: "adamw8bit" + lr: 1e-4 + + optimizer_params: + weight_decay: 1e-5 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Defaults off + ema_config: + use_ema: false + ema_decay: 0.99 + + # will probably need this if gpu supports it for flex, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "ostris/Flex.2-preview" + arch: "flex2" + quantize: true # run 8bit mixed precision + quantize_te: true + + # you can pass special training info for controls to the model here + # percentages are decimal based so 0.0 is 0% and 1.0 is 100% of the time. + model_kwargs: + # inverts the inpainting mask, good to learn outpainting as well, recommended 0.0 for characters + invert_inpaint_mask_chance: 0.5 + # this will do a normal t2i training step without inpaint when dropped out. Recommended if you want + # your lora to be able to inference with and without inpainting. + inpaint_dropout: 0.5 + # randomly drops out the control image. Dropout recommended if you want it to work without controls as well. + control_dropout: 0.5 + # does a random inpaint blob.
Usually a good idea to keep. Without it, the model will learn to always 100% + # fill the inpaint area with your subject. This is not always a good thing. + inpaint_random_chance: 0.5 + # generates random inpaint blobs if you did not provide an inpaint image for your dataset. Inpaint breaks down fast + # if you are not training with it. Controls are a little more robust and can be left out, + # but when in doubt, always leave this on + do_random_inpainting: false + # does random blurring of the inpaint mask. Helps prevent weird edge artifacts for real world inpainting. Leave on. + random_blur_mask: true + # applies a small amount of random dilation and restriction to the inpaint mask. Helps with edge artifacts. + # Leave on. + random_dialate_mask: true + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word + # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + + # you can use a single inpaint or single control image on your samples. + # for controls, the ctrl_idx is 1, the images can be any name and image format. + # use either a pose/line/depth image or whatever you are training with. An example is + # - "photo of [trigger] --ctrl_idx 1 --ctrl_img /path/to/control/image.jpg" + + # for an inpainting image, it must be png/webp. Erase the part of the image you want to inpaint + # IMPORTANT! the inpaint images must be ctrl_idx 0 and have .inpaint.{ext} in the name for this to work right.
+ # - "photo of [trigger] --ctrl_idx 0 --ctrl_img /path/to/inpaint/image.inpaint.png" + + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flex2 + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_flex_24gb.yaml b/ai-toolkit/config/examples/train_lora_flex_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..92bf849f57092f4878208b8caf5fdea96dfd1789 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_flex_24gb.yaml @@ -0,0 +1,101 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flex_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions + train: + batch_size: 1 + # IMPORTANT! For Flex, you must bypass the guidance embedder during training + bypass_guidance_embedding: true + + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with flex + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on.
+ ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flex, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "ostris/Flex.1-alpha" + is_flux: true + quantize: true # run 8bit mixed precision + quantize_kwargs: + exclude: + - "*time_text_embed*" # exclude the time text embedder from quantization + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flex + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_flux_24gb.yaml b/ai-toolkit/config/examples/train_lora_flux_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e29402b2668b23a215f5ddc10d083e55def7c61 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_flux_24gb.yaml @@ -0,0 +1,96 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "black-forest-labs/FLUX.1-dev" + is_flux: true + quantize: true # run 8bit mixed precision +# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
+ sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 20 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_flux_kontext_24gb.yaml b/ai-toolkit/config/examples/train_lora_flux_kontext_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d570da84eec4b3a2f890f6275e6e2a676aac5298 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_flux_kontext_24gb.yaml @@ -0,0 +1,106 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_kontext_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + # control path is the input images for kontext for a paired dataset. These are the source images you want to change. + # You can comment this out and only use normal images if you don't have a paired dataset. + # Control images need to match the filenames on the folder path but in + # a different folder. These do not need captions. + control_path: "/path/to/control/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + # Kontext runs images in at 2x the latent size. It may OOM at 1024 resolution with 24GB vram. + resolution: [ 512, 768 ] # flux enjoys multiple resolutions + # resolution: [ 512, 768, 1024 ] + train: + batch_size: 1 + steps: 3000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + timestep_type: "weighted" # sigmoid, linear, or weighted. + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + + # ema will smooth out learning, but could slow it down. + + # ema_config: + # use_ema: true + # ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path. This model is gated.
+ # visit https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev to accept the terms and conditions + # and then you can use this model. + name_or_path: "black-forest-labs/FLUX.1-Kontext-dev" + arch: "flux_kontext" + quantize: true # run 8bit mixed precision +# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower. + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word + # the --ctrl_img path is the one loaded to apply the kontext editing to +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg" + - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg" + - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg" + - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg" + - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg" + - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg" + - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg" + - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg" + - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg" + - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 20 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_flux_schnell_24gb.yaml b/ai-toolkit/config/examples/train_lora_flux_schnell_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4aef078d61765e70a9c1820109075af82367d9f --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_flux_schnell_24gb.yaml @@ -0,0 +1,98 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "black-forest-labs/FLUX.1-schnell" + assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training + is_flux: true + quantize: true # run 8bit mixed precision + # low_vram is painfully slow to fuse in the adapter; avoid it unless absolutely necessary +# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
+ sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 1 # schnell does not do guidance + sample_steps: 4 # 1 - 4 works well +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_hidream_48.yaml b/ai-toolkit/config/examples/train_lora_hidream_48.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f915f4c782a5cff4efa0afc00b9d0e58b7561a90 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_hidream_48.yaml @@ -0,0 +1,112 @@ +# HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train. +# It is not possible to train on a single 24GB card yet, but I am working on it. 
If you have more VRAM +# I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized. +# HiDream has a mixture of experts that may take special training considerations that I do not +# have implemented properly. The current implementation seems to work well for LoRA training, but +# may not be effective for longer training runs. The implementation could change in future updates +# so your results may vary when this happens. + +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_hidream_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 32 + linear_alpha: 32 + network_kwargs: + # it is probably best to ignore the mixture of experts since only 2 are active each block. It works activating it, but I wouldnt. + # proper training of it is not fully implemented + ignore_if_contains: + - "ff_i.experts" + - "ff_i.gate" + save: + dtype: bfloat16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + resolution: [ 512, 768, 1024 ] # hidream enjoys multiple resolutions + train: + batch_size: 1 + steps: 3000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # won't work with hidream + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + timestep_type: shift # sigmoid, shift, linear + optimizer: "adamw8bit" + lr: 2e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Defaults off + ema_config: + use_ema: false + ema_decay: 0.99 + + # will probably need this if gpu supports it for hidream, other dtypes may not work correctly + dtype: bf16 + model: + # the transformer will get grabbed from this hf repo + # warning ONLY train on Full. The dev and fast models are distilled and will break + name_or_path: "HiDream-ai/HiDream-I1-Full" + # the extras will be grabbed from this hf repo.
(text encoder, vae) + extras_name_or_path: "HiDream-ai/HiDream-I1-Full" + arch: "hidream" + # both need to be quantized to train on 48GB currently + quantize: true + quantize_te: true + model_kwargs: + # llama is a gated model, It defaults to unsloth version, but you can set the llama path here + llama_model_path: "unsloth/Meta-Llama-3.1-8B-Instruct" + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_lumina.yaml b/ai-toolkit/config/examples/train_lora_lumina.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e5d2d75666ff254f2ef120ea2d497522cd8ba319 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_lumina.yaml @@ -0,0 +1,96 @@ +--- +# This configuration requires 20GB of VRAM or more to operate +job: extension +config: + # this name will be the folder and filename name + name: "my_first_lumina_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps + # performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word + # trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: bf16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 2 # how many intermittent saves to keep + save_format: 'diffusers' # 'diffusers' + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + # cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions + train: + batch_size: 1 + + # can be 'sigmoid', 'linear', or 'lumina2_shift' + timestep_type: 'lumina2_shift' + + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with lumina2 + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample + # skip_first_sample: true + # uncomment to completely disable sampling + # disable_sampling: true + + # ema will smooth out learning, but could slow it down.
Recommended to leave on if you have the vram + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "Alpha-VLLM/Lumina-Image-2.0" + is_lumina2: true # lumina2 architecture + # you can quantize just the Gemma2 text encoder here to save vram + quantize_te: true + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word + # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear." 
+ - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4.0 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_omnigen2_24gb.yaml b/ai-toolkit/config/examples/train_lora_omnigen2_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6eb15302dfd23dc6cbc9b2bce422b4eac196fb6f --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_omnigen2_24gb.yaml @@ -0,0 +1,94 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_omnigen2_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # omnigen2 should work with multiple resolutions + train: + batch_size: 1 + steps: 3000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with omnigen2 + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + timestep_type: 'sigmoid' # sigmoid, linear, shift + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + + # ema will smooth out learning, but could slow it down. 
+ # ema_config: + # use_ema: true + # ema_decay: 0.99 + + # will probably need this if gpu supports it for omnigen2, other dtypes may not work correctly + dtype: bf16 + model: + name_or_path: "OmniGen2/OmniGen2" + arch: "omnigen2" + quantize_te: true # quantize_only te + # quantize: true # quantize transformer + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # negative prompt, optional + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_qwen_image_24gb.yaml b/ai-toolkit/config/examples/train_lora_qwen_image_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e464020f52b87cb20ffa28739eeaad9c62c17853 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_qwen_image_24gb.yaml @@ -0,0 +1,95 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_qwen_image_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word + # Trigger words will not work when caching text embeddings +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + # default_caption: "a person" # if caching text embeddings, if you don't have captions, this will get cached + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you have a large dataset + # if you OOM, 1024 may be too much, but should work + resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions + train: + batch_size: 1 + # caching text embeddings is required for 24GB + cache_text_embeddings: true + + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with qwen image + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "Qwen/Qwen-Image" + arch: "qwen_image" + quantize: true + # qtype_te: "qfloat8" Default float8 quantization + # to use the ARA use the | pipe to point to hf path, or a local path if you have one. 
+ # 3bit is required for 24GB + qtype: "uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors" + quantize_te: true + qtype_te: "qfloat8" + low_vram: true + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word + # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 3 + sample_steps: 25 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml b/ai-toolkit/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..845ac92b600fc2005b0892789e0a18617e7b846d --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml @@ -0,0 +1,105 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_qwen_image_edit_2509_lora_v1" + process: + - type: 'diffusion_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + # can do up to 3 control image folders, file names must match target file names, but aspect/size can be different + control_path: + - "/path/to/control/images/folder1" + - "/path/to/control/images/folder2" + - "/path/to/control/images/folder3" + caption_ext: "txt" + # default_caption: "a person" # if caching text embeddings, if you don't have captions, this will get cached + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions + # a trigger word that can be cached with the text embeddings + # trigger_word: "optional trigger word" + train: + batch_size: 1 + # caching text embeddings is required for 32GB + cache_text_embeddings: true + # unload_text_encoder: true + + steps: 3000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + timestep_type: "weighted" + train_unet: true + train_text_encoder: false # probably won't work with qwen image + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample + # skip_first_sample: true + # uncomment to completely disable sampling + # disable_sampling: true + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "Qwen/Qwen-Image-Edit-2509" + arch: "qwen_image_edit_plus" + quantize: true + # to use the ARA use the | pipe to point to hf path, or a local path if you have one. 
+ # 3bit is required for 32GB + qtype: "uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors" + quantize_te: true + qtype_te: "qfloat8" + low_vram: true + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + # you can provide up to 3 control images here + samples: + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 3 + sample_steps: 25 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_qwen_image_edit_32gb.yaml b/ai-toolkit/config/examples/train_lora_qwen_image_edit_32gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..81a999d7997848affcfc8f105f540b966d9e1775 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_qwen_image_edit_32gb.yaml @@ -0,0 +1,102 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_qwen_image_edit_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word + # Trigger words will not work when caching text embeddings +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + control_path: "/path/to/control/images/folder" + caption_ext: "txt" + # default_caption: "a person" # if caching text embeddings, if you don't have captions, this will get cached + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions + train: + batch_size: 1 + # caching text embeddings is required for 32GB + cache_text_embeddings: true + + steps: 3000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + timestep_type: "weighted" + train_unet: true + train_text_encoder: false # probably won't work with qwen image + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "Qwen/Qwen-Image-Edit" + arch: "qwen_image_edit" + quantize: true + # qtype_te: "qfloat8" Default float8 qquantization + # to use the ARA use the | pipe to point to hf path, or a local path if you have one. 
+ # 3bit is required for 32GB + qtype: "uint3|qwen_image_edit_torchao_uint3.safetensors" + quantize_te: true + qtype_te: "qfloat8" + low_vram: true + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + samples: + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + - prompt: "do the thing to it" + ctrl_img: "/path/to/control/image.jpg" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 3 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_sd35_large_24gb.yaml b/ai-toolkit/config/examples/train_lora_sd35_large_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1766c39220a38134c9db8319d2b7e243f8b8f43 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_sd35_large_24gb.yaml @@ -0,0 +1,97 @@ +--- +# NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. 
SOME THINGS WILL CHANGE +job: extension +config: + # this name will be the folder and filename name + name: "my_first_sd3l_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 1024 ] + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # May not fully work with SD3 yet + gradient_checkpointing: true # need this on unless you have a ton of vram + noise_scheduler: "flowmatch" + timestep_type: "linear" # linear or sigmoid + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. 
+ ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for sd3, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "stabilityai/stable-diffusion-3.5-large" + is_v3: true + quantize: true # run 8bit mixed precision + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 25 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_wan21_14b_24gb.yaml b/ai-toolkit/config/examples/train_lora_wan21_14b_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32babd14ccd2ad89a77924520eeaea2a7abeb094 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_wan21_14b_24gb.yaml @@ -0,0 +1,101 @@ +# IMPORTANT: The Wan2.1 14B model is huge. This config should work on 24GB GPUs. It cannot +# support keeping the text encoder on GPU while training with 24GB, so it is only good +# for training on a single prompt, for example a person with a trigger word. +# To train on captions, you need more vram for now. +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_wan21_14b_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word + # this is probably needed for 24GB cards when offloading TE to CPU + trigger_word: "p3r5on" + network: + type: "lora" + linear: 32 + linear_alpha: 32 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. 
captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time + # it works well for characters, but not as well for "actions" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 632 ] # will be around 480p + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with wan + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + timestep_type: 'sigmoid' + optimizer: "adamw8bit" + lr: 1e-4 + optimizer_params: + weight_decay: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # ema will smooth out learning, but could slow it down. Recommended to leave on. 
+ ema_config: + use_ema: true + ema_decay: 0.99 + dtype: bf16 + # required for 24GB cards + # this will encode your trigger word and use those embeddings for every image in the dataset + unload_text_encoder: true + model: + # huggingface model name or path + name_or_path: "Wan-AI/Wan2.1-T2V-14B-Diffusers" + arch: 'wan21' + # these settings will save as much vram as possible + quantize: true + quantize_te: true + low_vram: true + sample: + sampler: "flowmatch" + sample_every: 250 # sample every this many steps + width: 832 + height: 480 + num_frames: 40 + fps: 15 + # samples take a long time. so use them sparingly + # samples will be animated webp files, if you don't see them animated, open in a browser. + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 5 + sample_steps: 30 +# you can add any additional meta info here. 
[name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_wan21_1b_24gb.yaml b/ai-toolkit/config/examples/train_lora_wan21_1b_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e75f7ab032ffc7ee41a37eb6a951d4d15a475ec --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_wan21_1b_24gb.yaml @@ -0,0 +1,90 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_wan21_1b_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 32 + linear_alpha: 32 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time + # it works well for characters, but not as well for "actions" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 632 ] # will be around 480p + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with wan + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + timestep_type: 'sigmoid' + optimizer: "adamw8bit" + lr: 1e-4 + optimizer_params: + weight_decay: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + arch: 'wan21' + quantize_te: true # saves vram + sample: + sampler: "flowmatch" + sample_every: 250 # sample every this many steps + width: 832 + height: 480 + num_frames: 40 + fps: 15 + # samples take a long time. so use them sparingly + # samples will be animated webp files, if you don't see them animated, open in a browser. 
+ prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 5 + sample_steps: 30 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_lora_wan22_14b_24gb.yaml b/ai-toolkit/config/examples/train_lora_wan22_14b_24gb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..966f184f471207519948751931d39dd65c57e622 --- /dev/null +++ b/ai-toolkit/config/examples/train_lora_wan22_14b_24gb.yaml @@ -0,0 +1,111 @@ +# this example focuses mainly for training Wan2.2 14b on images. It will work for video as well by increasing +# the number of frames in the dataset and samples. Training on and generating video is very VRAM intensive. +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_wan22_14b_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # Use a trigger word if train.unload_text_encoder is true, however, if caching text embeddings, do not use a trigger word + # trigger_word: "p3r5on" + network: + type: "lora" + linear: 32 + linear_alpha: 32 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
+ # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/or/video/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + # number of frames to extract from your video. It will automatically extract them evenly spaced + # set to 1 frame for images + num_frames: 1 + resolution: [ 512, 768, 1024] + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with wan + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + timestep_type: 'linear' + optimizer: "adamw8bit" + lr: 1e-4 + optimizer_params: + weight_decay: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + dtype: bf16 + + # IMPORTANT: this is for Wan 2.2 MOE. It will switch training one stage or the other every this many steps + switch_boundary_every: 10 + + # required for 24GB cards. You must do either unload_text_encoder or cache_text_embeddings but not both + + # this will encode your trigger word and use those embeddings for every image in the dataset, captions will be ignored + # unload_text_encoder: true + + # this will cache all captions in your dataset. + cache_text_embeddings: true + + model: + # huggingface model name or path, this one if bf16, vs the float32 of the official repo + name_or_path: "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16" + arch: 'wan22_14b' + quantize: true + # This will pull and use a custom Accuracy Recovery Adapter to train at 4bit + qtype: "uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors" + quantize_te: true + qtype_te: "qfloat8" + low_vram: true + model_kwargs: + # you can train high noise, low noise, or both. With low vram it will automatically unload the one not being trained. 
+ train_high_noise: true + train_low_noise: true + sample: + sampler: "flowmatch" + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + # set to 1 for images + num_frames: 1 + fps: 16 + # samples take a long time. so use them sparingly + # samples will be animated webp files, if you don't see them animated, open in a browser. + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 3.5 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/ai-toolkit/config/examples/train_slider.example.yml b/ai-toolkit/config/examples/train_slider.example.yml new file mode 100644 index 0000000000000000000000000000000000000000..b36009175af6b384527a2fb57e0f316caba6049a --- /dev/null +++ b/ai-toolkit/config/examples/train_slider.example.yml @@ -0,0 +1,230 @@ +--- +# This is in yaml format. 
You can use json if you prefer +# I like both but yaml is easier to write +# Plus it has comments which is nice for documentation +# This is the config I use on my sliders, It is solid and tested +job: train +config: + # the name will be used to create a folder in the output folder + # it will also replace any [name] token in the rest of this config + name: detail_slider_v1 + # folder will be created with name above in folder below + # it can be relative to the project root or absolute + training_folder: "output/LoRA" + device: cuda:0 # cpu, cuda:0, etc + # for tensorboard logging, we will make a subfolder for this job + log_dir: "output/.tensorboard" + # you can stack processes for other jobs, It is not tested with sliders though + # just use one for now + process: + - type: slider # tells runner to run the slider process + # network is the LoRA network for a slider, I recommend to leave this be + network: + # network type lierla is traditional LoRA that works everywhere, only linear layers + type: "lierla" + # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good + linear: 8 + linear_alpha: 4 # Do about half of rank + # training config + train: + # this is also used in sampling. Stick with ddpm unless you know what you are doing + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + # how many steps to train. More is not always better. I rarely go over 1000 + steps: 500 + # I have had good results with 4e-4 to 1e-4 at 500 steps + lr: 2e-4 + # enables gradient checkpoint, saves vram, leave it on + gradient_checkpointing: true + # train the unet. I recommend leaving this true + train_unet: true + # train the text encoder. 
I don't recommend this unless you have a special use case + # for sliders we are adjusting representation of the concept (unet), + # not the description of it (text encoder) + train_text_encoder: false + # same as from sd-scripts, not fully tested but should speed up training + min_snr_gamma: 5.0 + # just leave unless you know what you are doing + # also supports "dadaptation" but set lr to 1 if you use that, + # but it learns too fast and I don't recommend it + optimizer: "adamw" + # only constant for now + lr_scheduler: "constant" + # we randomly denoise a random number of steps from 1 to this number + # while training. Just leave it + max_denoising_steps: 40 + # works great at 1. I do 1 even with my 4090. + # higher may not work right with newer single batch stacking code anyway + batch_size: 1 + # bf16 works best if your GPU supports it (modern) + dtype: bf16 # fp32, bf16, fp16 + # if you have it, use it. It is faster and better + # torch 2.0 doesn't need xformers anymore, only use if you have lower version +# xformers: true + # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX + # although, the way we train sliders is comparative, so it probably won't work anyway + noise_offset: 0.0 +# noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL + + # the model to train the LoRA network on + model: + # huggingface name, path relative to the project root, or absolute path to .safetensors or .ckpt + name_or_path: "runwayml/stable-diffusion-v1-5" + is_v2: false # for v2 models + is_v_pred: false # for v-prediction models (most v2 models) + # has some issues with the dual text encoder and the way we train sliders + # it works, but weights probably need to be higher to see it. + is_xl: false # for SDXL models + + # saving config + save: + dtype: float16 # precision to save. 
I recommend float16 + save_every: 50 # save every this many steps + # this will remove step counts more than this number + # allows you to save more often in case of a crash without filling up your drive + max_step_saves_to_keep: 2 + + # sampling config + sample: + # must match train.noise_scheduler, this is not used here + # but may be in future and in other processes + sampler: "ddpm" + # sample every this many steps + sample_every: 20 + # image size + width: 512 + height: 512 + # prompts to use for sampling. Do as many as you want, but it slows down training + # pick ones that will best represent the concept you are trying to adjust + # allows some flags after the prompt + # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive + # slide are good tests. will inherit sample.network_multiplier if not set + # --n [string] # negative prompt, will inherit sample.neg if not set + # Only 75 tokens allowed currently + # I like to do a wide positive and negative spread so I can see a good range and stop + # early if the network is braking down + prompts: + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5" + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3" + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3" + - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5" + - "a golden retriever sitting on a leather couch, --m -5" + - "a golden retriever sitting on a leather couch --m -3" + - "a golden retriever sitting on a leather couch --m 3" + - "a golden retriever sitting on a leather couch --m 5" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3" + - "a man with a beard and red flannel shirt, wearing vr goggles, walking 
into traffic --m 5" + # negative prompt used on all prompts above as default if they don't have one + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome" + # seed for sampling. 42 is the answer for everything + seed: 42 + # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc + # will start over on next sample_every so s1 is always seed + # works well if you use same prompt but want different results + walk_seed: false + # cfg scale (4 to 10 is good) + guidance_scale: 7 + # sampler steps (20 to 30 is good) + sample_steps: 20 + # default network multiplier for all prompts + # since we are training a slider, I recommend overriding this with --m [number] + # in the prompts above to get both sides of the slider + network_multiplier: 1.0 + + # logging information + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false # probably done need unless you are debugging + + # slider training config, best for last + slider: + # resolutions to train on. [ width, height ]. This is less important for sliders + # as we are not teaching the model anything it doesn't already know + # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1 + # and [ 1024, 1024 ] for sd_xl + # you can do as many as you want here + resolutions: + - [ 512, 512 ] +# - [ 512, 768 ] +# - [ 768, 768 ] + # slider training uses 4 combined steps for a single round. This will do it in one gradient + # step. It is highly optimized and shouldn't take anymore vram than doing without it, + # since we break down batches for gradient accumulation now. so just leave it on. + batch_full_slide: true + # These are the concepts to train on. You can do as many as you want here, + # but they can conflict outweigh each other. 
Other than experimenting, I recommend + # just doing one for good results + targets: + # target_class is the base concept we are adjusting the representation of + # for example, if we are adjusting the representation of a person, we would use "person" + # if we are adjusting the representation of a cat, we would use "cat" It is not + # a keyword necessarily but what the model understands the concept to represent. + # "person" will affect men, women, children, etc but will not affect cats, dogs, etc + # it is the models base general understanding of the concept and everything it represents + # you can leave it blank to affect everything. In this example, we are adjusting + # detail, so we will leave it blank to affect everything + - target_class: "" + # positive is the prompt for the positive side of the slider. + # It is the concept that will be excited and amplified in the model when we slide the slider + # to the positive side and forgotten / inverted when we slide + # the slider to the negative side. It is generally best to include the target_class in + # the prompt. You want it to be the extreme of what you want to train on. For example, + # if you want to train on fat people, you would use "an extremely fat, morbidly obese person" + # as the prompt. Not just "fat person" + # max 75 tokens for now + positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" + # negative is the prompt for the negative side of the slider and works the same as positive + # it does not necessarily work the same as a negative prompt when generating images + # these need to be polar opposites. + # max 76 tokens for now + negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" + # the loss for this target is multiplied by this number. 
+ # if you are doing more than one target it may be good to set less important ones + # to a lower number like 0.1 so they don't outweigh the primary target + weight: 1.0 + # shuffle the prompts split by the comma. We will run every combination randomly + # this will make the LoRA more robust. You probably want this on unless prompt order + # is important for some reason + shuffle: true + + + # anchors are prompts that we will try to hold on to while training the slider + # these are NOT necessary and can prevent the slider from converging if not done right + # leave them off if you are having issues, but they can help lock the network + # on certain concepts to help prevent catastrophic forgetting + # you want these to generate an image that is not your target_class, but close to it + # is fine as long as it does not directly overlap it. + # For example, if you are training on a person smiling, + # you could use "a person with a face mask" as an anchor. It is a person, the image is the same + # regardless if they are smiling or not, however, the closer the concept is to the target_class + # the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually + # for close concepts, you want to be closer to 0.1 or 0.2 + # these will slow down training. I am leaving them off for the demo + +# anchors: +# - prompt: "a woman" +# neg_prompt: "animal" +# # the multiplier applied to the LoRA when this is run. +# # higher will give it more weight but also help keep the lora from collapsing +# multiplier: 1.0 +# - prompt: "a man" +# neg_prompt: "animal" +# multiplier: 1.0 +# - prompt: "a person" +# neg_prompt: "animal" +# multiplier: 1.0 + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. 
The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/ai-toolkit/dgx_instructions.md b/ai-toolkit/dgx_instructions.md new file mode 100644 index 0000000000000000000000000000000000000000..06198aa8a747473560eb4c221f0c852225219efb --- /dev/null +++ b/ai-toolkit/dgx_instructions.md @@ -0,0 +1,84 @@ +# AI Toolkit by Ostris + +## DGX OS installation instructions + +You need to use Python 3.11 to run AI Toolkit on DGX OS. The easiest way to do this without affecting the system installation of Python is to create a virtual environment with **miniconda**, which allows you to specify the version of Python to use in the environment. + +This guide will assume you have a fresh installation of DGX OS, and will guide you through the installation of all requirements. + +### Installation instructions for DGX OS: + +**1) Get Python 3.11 (via miniconda)** + +Install the latest version of miniconda: +``` +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh +chmod u+x Miniconda3-latest-Linux-aarch64.sh +./Miniconda3-latest-Linux-aarch64.sh +``` + +Restart your bash or ssh session. If miniconda was installed successfully, it will automatically load the 'base' environment by default. 
If you want to disable this behaviour, run: +``` +conda config --set auto_activate_base false +``` + +Now you can create a Python 3.11 environment for ai-toolkit: +``` +conda create --name ai-toolkit python=3.11 +``` + +Then activate the environment with: + +``` +conda activate ai-toolkit +``` + + +**2) Install PyTorch** + +``` +pip3 install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu130 +``` + + +**3) Install the remaining requirements (dgx_requirements.txt)** + +``` +pip3 install -r dgx_requirements.txt +``` + +### Running the UI on DGX OS: + +Running the UI is not that different from doing it on other systems, however, you need to install the ARM64 version of NodeJS for Linux, which is compatible with the NVIDIA Grace CPU. + + +**1) Install Node.js** + +Download a Linux ARM64 build of Node.js from: https://nodejs.org (for example: https://nodejs.org/dist/v24.11.1/node-v24.11.1-linux-arm64.tar.xz) + +Extract it and add the bin directory to your path. I extracted it to **/opt** and added the following to my ~/.bashrc file: +``` +export PATH="/opt/node-v24.11.1-linux-arm64/bin:$PATH" +``` + + +**2) Compile and run the Node.js UI** + +Change to the ui directory, then build and run the UI: +``` +cd ui +npm run build_and_start +``` + +If all went well, you’ll be able to access the UI on port 8675 and start training. + + +
+ Troubleshooting issues +If you’re not getting any output when starting a training job from the UI, the job is probably crashing before the training process starts. The best way to debug this is to run the Python training script directly (it is normally started by the UI). To do this, set up a training job in the UI, go to the advanced config screen, copy and paste the configuration into a file such as train.yaml, then run the training script with the conda virtual environment active: + +``` +python run.py path/to/train.yaml +``` +
+
\ No newline at end of file diff --git a/ai-toolkit/dgx_requirements.txt b/ai-toolkit/dgx_requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8e507e664b955cf66b2618b599c2250987d6861 --- /dev/null +++ b/ai-toolkit/dgx_requirements.txt @@ -0,0 +1,13 @@ +# You need to use Python 3.11, the easiest way to get this on DGX OS without impacting the system version of Python is to create an environment with miniconda. + +# specific dependency versions needed on DGX OS devices: +scipy==1.16.0 +tifffile==2025.6.11 +imageio==2.37.0 +scikit_image==0.25.2 +clean_fid==0.1.35 +pywavelets==1.9.0 +contourpy==1.3.3 +opencv_python_headless==4.11.0.86 + +-r requirements_base.txt diff --git a/ai-toolkit/docker-compose.yml b/ai-toolkit/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..bdf86aedeb608eff11324e4761e0a8a4584b07bf --- /dev/null +++ b/ai-toolkit/docker-compose.yml @@ -0,0 +1,25 @@ +version: "3.8" + +services: + ai-toolkit: + image: ostris/aitoolkit:latest + restart: unless-stopped + ports: + - "8675:8675" + volumes: + - ~/.cache/huggingface/hub:/root/.cache/huggingface/hub + - ./aitk_db.db:/app/ai-toolkit/aitk_db.db + - ./datasets:/app/ai-toolkit/datasets + - ./output:/app/ai-toolkit/output + - ./config:/app/ai-toolkit/config + environment: + - AI_TOOLKIT_AUTH=${AI_TOOLKIT_AUTH:-password} + - NODE_ENV=production + - TZ=UTC + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] diff --git a/ai-toolkit/docker/Dockerfile b/ai-toolkit/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..e7208f2e97f0ed0dc9b1384c825d36cc2d9c5063 --- /dev/null +++ b/ai-toolkit/docker/Dockerfile @@ -0,0 +1,108 @@ +FROM nvidia/cuda:12.8.1-devel-ubuntu24.04 + +LABEL authors="jaret" + +# Set noninteractive to avoid timezone prompts +ENV DEBIAN_FRONTEND=noninteractive + +# ref https://en.wikipedia.org/wiki/CUDA +ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0 
10.0 12.0" + +# Install dependencies +RUN apt-get update && apt-get install --no-install-recommends -y \ + git \ + curl \ + build-essential \ + cmake \ + wget \ + python3.12 \ + python3-pip \ + python3-dev \ + python3-setuptools \ + python3-wheel \ + python3-venv \ + ffmpeg \ + tmux \ + htop \ + nvtop \ + python3-opencv \ + openssh-client \ + openssh-server \ + openssl \ + rsync \ + unzip \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# Install nodejs +WORKDIR /tmp +RUN curl -sL https://deb.nodesource.com/setup_23.x -o nodesource_setup.sh && \ + bash nodesource_setup.sh && \ + apt-get update && \ + apt-get install -y nodejs && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Set aliases for python and pip +RUN ln -s /usr/bin/python3 /usr/bin/python + +# install pytorch before cache bust to avoid redownloading pytorch +RUN pip install --no-cache-dir torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128 --break-system-packages + +WORKDIR /app/ai-toolkit + +# ---------------------------------------------------------------------------- # +# Dependency layers come BEFORE the source clone so they are only rebuilt (and +# only need to be re-pulled by servers) when the dependency manifests change, +# not on every code change. 
+# ---------------------------------------------------------------------------- # + +# Install Python dependencies (only re-runs when the requirements files change) +COPY requirements.txt requirements_base.txt /app/ai-toolkit/ +RUN pip install --no-cache-dir --break-system-packages -r requirements.txt && \ + pip install setuptools==69.5.1 --no-cache-dir --break-system-packages + +# Install Node dependencies (only re-runs when package.json / package-lock.json change) +COPY ui/package.json ui/package-lock.json /app/ai-toolkit/ui/ +RUN cd /app/ai-toolkit/ui && npm ci + +# ---------------------------------------------------------------------------- # +# Source code comes LAST. Only this layer (plus the UI build below) is rebuilt +# on a code change, so servers only re-pull the (small) source, not the deps. +# Clone to a temp dir and rsync the source in, preserving the dependency dirs +# already populated above (ui/node_modules) and the manifests already used. +# ---------------------------------------------------------------------------- # +ARG CACHEBUST=1234 +ARG GIT_COMMIT=main +RUN echo "Cache bust: ${CACHEBUST}" && \ + git clone https://github.com/ostris/ai-toolkit.git /tmp/ai-toolkit-src && \ + cd /tmp/ai-toolkit-src && \ + git checkout ${GIT_COMMIT} && \ + rsync -a --delete \ + --exclude 'ui/node_modules' \ + --exclude 'requirements.txt' \ + --exclude 'ui/package.json' \ + --exclude 'ui/package-lock.json' \ + /tmp/ai-toolkit-src/ /app/ai-toolkit/ && \ + rm -rf /tmp/ai-toolkit-src + +# Build UI (re-runs on code change, but reuses the cached node_modules above). +# update_db runs first because it does `prisma generate`, which creates the +# @prisma/client types the TS build needs. In the old layout generate happened +# as a side effect of npm install seeing the schema; now the source arrives +# after npm ci, so run it explicitly before the build. 
+RUN cd /app/ai-toolkit/ui && \ + npm run update_db && \ + npm run build + +# Expose the port the AI Toolkit UI listens on (8675) +EXPOSE 8675 + +WORKDIR / + +COPY docker/start.sh /start.sh +RUN chmod +x /start.sh + +CMD ["/start.sh"] \ No newline at end of file diff --git a/ai-toolkit/docker/start.sh b/ai-toolkit/docker/start.sh new file mode 100644 index 0000000000000000000000000000000000000000..abcbca0500d7d728d954d755cad29bfc3ce56f79 --- /dev/null +++ b/ai-toolkit/docker/start.sh @@ -0,0 +1,70 @@ +#!/bin/bash +set -e # Exit the script if any statement returns a non-true return value + +# ref https://github.com/runpod/containers/blob/main/container-template/start.sh + +# ---------------------------------------------------------------------------- # +# Function Definitions # +# ---------------------------------------------------------------------------- # + + +# Setup ssh +setup_ssh() { + if [[ $PUBLIC_KEY ]]; then + echo "Setting up SSH..." + mkdir -p ~/.ssh + echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys + chmod 700 -R ~/.ssh + + if [ ! -f /etc/ssh/ssh_host_rsa_key ]; then + ssh-keygen -t rsa -f /etc/ssh/ssh_host_rsa_key -q -N '' + echo "RSA key fingerprint:" + ssh-keygen -lf /etc/ssh/ssh_host_rsa_key.pub + fi + + if [ ! -f /etc/ssh/ssh_host_dsa_key ]; then + ssh-keygen -t dsa -f /etc/ssh/ssh_host_dsa_key -q -N '' + echo "DSA key fingerprint:" + ssh-keygen -lf /etc/ssh/ssh_host_dsa_key.pub + fi + + if [ ! -f /etc/ssh/ssh_host_ecdsa_key ]; then + ssh-keygen -t ecdsa -f /etc/ssh/ssh_host_ecdsa_key -q -N '' + echo "ECDSA key fingerprint:" + ssh-keygen -lf /etc/ssh/ssh_host_ecdsa_key.pub + fi + + if [ ! 
-f /etc/ssh/ssh_host_ed25519_key ]; then + ssh-keygen -t ed25519 -f /etc/ssh/ssh_host_ed25519_key -q -N '' + echo "ED25519 key fingerprint:" + ssh-keygen -lf /etc/ssh/ssh_host_ed25519_key.pub + fi + + service ssh start + + echo "SSH host keys:" + for key in /etc/ssh/*.pub; do + echo "Key: $key" + ssh-keygen -lf $key + done + fi +} + +# Export env vars +export_env_vars() { + echo "Exporting environment variables..." + printenv | grep -E '^RUNPOD_|^PATH=|^_=' | awk -F = '{ print "export " $1 "=\"" $2 "\"" }' >> /etc/rp_environment + echo 'source /etc/rp_environment' >> ~/.bashrc +} + +# ---------------------------------------------------------------------------- # +# Main Program # +# ---------------------------------------------------------------------------- # + + +echo "Pod Started" + +setup_ssh +export_env_vars +echo "Starting AI Toolkit UI..." +cd /app/ai-toolkit/ui && npm run start \ No newline at end of file diff --git a/ai-toolkit/extensions/example/ExampleMergeModels.py b/ai-toolkit/extensions/example/ExampleMergeModels.py new file mode 100644 index 0000000000000000000000000000000000000000..162d514c38799b1c5cdc4717e1fa9867a4b35572 --- /dev/null +++ b/ai-toolkit/extensions/example/ExampleMergeModels.py @@ -0,0 +1,129 @@ +import torch +import gc +from collections import OrderedDict +from typing import TYPE_CHECKING +from jobs.process import BaseExtensionProcess +from toolkit.config_modules import ModelConfig +from toolkit.stable_diffusion_model import StableDiffusion +from toolkit.train_tools import get_torch_dtype +from tqdm import tqdm + +# Type check imports. 
Prevents circular imports +if TYPE_CHECKING: + from jobs import ExtensionJob + + +# extend standard config classes to add weight +class ModelInputConfig(ModelConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.weight = kwargs.get('weight', 1.0) + # overwrite default dtype unless user specifies otherwise + # float32 will give us better precision on the merging functions + self.dtype: str = kwargs.get('dtype', 'float32') + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +# this is our main class process +class ExampleMergeModels(BaseExtensionProcess): + def __init__( + self, + process_id: int, + job: 'ExtensionJob', + config: OrderedDict + ): + super().__init__(process_id, job, config) + # this is the setup process, do not do process intensive stuff here, just variable setup and + # checking requirements. This is called before the run() function + # no loading models or anything like that, it is just for setting up the process + # all of your process intensive stuff should be done in the run() function + # config will have everything from the process item in the config file + + # convenience methods exist on BaseProcess to get config values + # if required is set to true and the value is not found it will throw an error + # you can pass a default value to get_conf() as well if it was not in the config file + # as well as a type to cast the value to + self.save_path = self.get_conf('save_path', required=True) + self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype) + self.device = self.get_conf('device', default='cpu', as_type=torch.device) + + # build models to merge list + models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list) + # build list of ModelInputConfig objects. I find it is a good idea to make a class for each config + # this way you can add methods to it and it is easier to read and code. 
There are a lot of + # inbuilt config classes located in toolkit.config_modules as well + self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge] + # setup is complete. Don't load anything else here, just setup variables and stuff + + # this is the entire run process be sure to call super().run() first + def run(self): + # always call first + super().run() + print(f"Running process: {self.__class__.__name__}") + + # let's adjust our weights first to normalize them so the total is 1.0 + total_weight = sum([model.weight for model in self.models_to_merge]) + weight_adjust = 1.0 / total_weight + for model in self.models_to_merge: + model.weight *= weight_adjust + + output_model: StableDiffusion = None + # let's do the merge, it is a good idea to use tqdm to show progress + for model_config in tqdm(self.models_to_merge, desc="Merging models"): + # setup model class with our helper class + sd_model = StableDiffusion( + device=self.device, + model_config=model_config, + dtype="float32" + ) + # load the model + sd_model.load_model() + + # adjust the weight of the text encoder + if isinstance(sd_model.text_encoder, list): + # sdxl model + for text_encoder in sd_model.text_encoder: + for key, value in text_encoder.state_dict().items(): + value *= model_config.weight + else: + # normal model + for key, value in sd_model.text_encoder.state_dict().items(): + value *= model_config.weight + # adjust the weights of the unet + for key, value in sd_model.unet.state_dict().items(): + value *= model_config.weight + + if output_model is None: + # use this one as the base + output_model = sd_model + else: + # merge the models + # text encoder + if isinstance(output_model.text_encoder, list): + # sdxl model + for i, text_encoder in enumerate(output_model.text_encoder): + for key, value in text_encoder.state_dict().items(): + value += sd_model.text_encoder[i].state_dict()[key] + else: + # normal model + for key, value in 
output_model.text_encoder.state_dict().items(): + value += sd_model.text_encoder.state_dict()[key] + # unet + for key, value in output_model.unet.state_dict().items(): + value += sd_model.unet.state_dict()[key] + + # remove the model to free memory + del sd_model + flush() + + # merge loop is done, let's save the model + print(f"Saving merged model to {self.save_path}") + output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype) + print(f"Saved merged model to {self.save_path}") + # do cleanup here + del output_model + flush() diff --git a/ai-toolkit/extensions/example/__init__.py b/ai-toolkit/extensions/example/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34f348f1a1c278d7b71d48a60d5ddf909141b7db --- /dev/null +++ b/ai-toolkit/extensions/example/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# We make a subclass of Extension +class ExampleMergeExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "example_merge_extension" + + # name is the name of the extension for printing + name = "Example Merge Extension" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ExampleMergeModels import ExampleMergeModels + return ExampleMergeModels + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ExampleMergeExtension +] diff --git a/ai-toolkit/extensions/example/config/config.example.yaml b/ai-toolkit/extensions/example/config/config.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..abed03fd9e9197e072a73203c67ad7ee976912cd --- /dev/null +++ b/ai-toolkit/extensions/example/config/config.example.yaml @@ -0,0 
+1,48 @@ +--- +# Always include at least one example config file to show how to use your extension. +# use plenty of comments so users know how to use it and what everything does + +# all extensions will use this job name +job: extension +config: + name: 'my_awesome_merge' + process: + # Put your example processes here. This will be passed + # to your extension process in the config argument. + # the type MUST match your extension uid + - type: "example_merge_extension" + # save path for the merged model + save_path: "output/merge/[name].safetensors" + # save type + dtype: fp16 + # device to run it on + device: cuda:0 + # input models can only be SD1.x and SD2.x models for this example (currently) + models_to_merge: + # weights are relative, total weights will be normalized + # for example. If you have 2 models with weight 1.0, they will + # both be weighted 0.5. If you have 1 model with weight 1.0 and + # another with weight 2.0, the first will be weighted 1/3 and the + # second will be weighted 2/3 + - name_or_path: "input/model1.safetensors" + weight: 1.0 + - name_or_path: "input/model2.safetensors" + weight: 1.0 + - name_or_path: "input/model3.safetensors" + weight: 0.3 + - name_or_path: "input/model4.safetensors" + weight: 1.0 + + +# you can put any information you want here, and it will be saved in the model +# the below is an example. I recommend doing trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. 
\ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/advanced_generator/Img2ImgGenerator.py b/ai-toolkit/extensions_built_in/advanced_generator/Img2ImgGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..58713bd0c570e1240811cefd8e8b1f390346c069 --- /dev/null +++ b/ai-toolkit/extensions_built_in/advanced_generator/Img2ImgGenerator.py @@ -0,0 +1,256 @@ +import math +import os +import random +from collections import OrderedDict +from typing import List + +import numpy as np +from PIL import Image +from diffusers import T2IAdapter +from diffusers.utils.torch_utils import randn_tensor +from torch.utils.data import DataLoader +from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline +from tqdm import tqdm + +from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +from toolkit.sampler import get_sampler +from toolkit.stable_diffusion_model import StableDiffusion +import gc +import torch +from jobs.process import BaseExtensionProcess +from toolkit.data_loader import get_dataloader_from_datasets +from toolkit.train_tools import get_torch_dtype +from controlnet_aux.midas import MidasDetector +from diffusers.utils import load_image +from torchvision.transforms import ToTensor + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + + + + +class GenerateConfig: + + def __init__(self, **kwargs): + self.prompts: List[str] + self.sampler = kwargs.get('sampler', 'ddpm') + self.neg = kwargs.get('neg', '') + self.seed = kwargs.get('seed', -1) + self.walk_seed = kwargs.get('walk_seed', False) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.ext = kwargs.get('ext', 'png') + self.denoise_strength = kwargs.get('denoise_strength', 0.5) + 
self.trigger_word = kwargs.get('trigger_word', None) + + +class Img2ImgGenerator(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.copy_inputs_to = self.get_conf('copy_inputs_to', None) + self.device = self.get_conf('device', 'cuda') + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.generate_config = GenerateConfig(**self.get_conf('generate', required=True)) + self.is_latents_cached = True + raw_datasets = self.get_conf('datasets', None) + if raw_datasets is not None and len(raw_datasets) > 0: + raw_datasets = preprocess_dataset_raw_config(raw_datasets) + self.datasets = None + self.datasets_reg = None + self.dtype = self.get_conf('dtype', 'float16') + self.torch_dtype = get_torch_dtype(self.dtype) + self.params = [] + if raw_datasets is not None and len(raw_datasets) > 0: + for raw_dataset in raw_datasets: + dataset = DatasetConfig(**raw_dataset) + is_caching = dataset.cache_latents or dataset.cache_latents_to_disk + if not is_caching: + self.is_latents_cached = False + if dataset.is_reg: + if self.datasets_reg is None: + self.datasets_reg = [] + self.datasets_reg.append(dataset) + else: + if self.datasets is None: + self.datasets = [] + self.datasets.append(dataset) + + self.progress_bar = None + self.sd = StableDiffusion( + device=self.device, + model_config=self.model_config, + dtype=self.dtype, + ) + print(f"Using device {self.device}") + self.data_loader: DataLoader = None + self.adapter: T2IAdapter = None + + def to_pil(self, img): + # image comes in -1 to 1. 
convert to a PIL RGB image + img = (img + 1) / 2 + img = img.clamp(0, 1) + img = img[0].permute(1, 2, 0).cpu().numpy() + img = (img * 255).astype(np.uint8) + image = Image.fromarray(img) + return image + + def run(self): + with torch.no_grad(): + super().run() + print("Loading model...") + self.sd.load_model() + device = torch.device(self.device) + + if self.model_config.is_xl: + pipe = StableDiffusionXLImg2ImgPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder[0], + text_encoder_2=self.sd.text_encoder[1], + tokenizer=self.sd.tokenizer[0], + tokenizer_2=self.sd.tokenizer[1], + scheduler=get_sampler(self.generate_config.sampler), + ).to(device, dtype=self.torch_dtype) + elif self.model_config.is_pixart: + pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype) + else: + raise NotImplementedError("Only XL models are supported") + pipe.set_progress_bar_config(disable=True) + + # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True) + + self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd) + + num_batches = len(self.data_loader) + pbar = tqdm(total=num_batches, desc="Generating images") + seed = self.generate_config.seed + # load images from datasets, use tqdm + for i, batch in enumerate(self.data_loader): + batch: DataLoaderBatchDTO = batch + + gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1) + generator = torch.manual_seed(gen_seed) + + file_item: FileItemDTO = batch.file_items[0] + img_path = file_item.path + img_filename = os.path.basename(img_path) + img_filename_no_ext = os.path.splitext(img_filename)[0] + img_filename = img_filename_no_ext + '.' 
+ self.generate_config.ext + output_path = os.path.join(self.output_folder, img_filename) + output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt') + + if self.copy_inputs_to is not None: + output_inputs_path = os.path.join(self.copy_inputs_to, img_filename) + output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt') + else: + output_inputs_path = None + output_inputs_caption_path = None + + caption = batch.get_caption_list()[0] + if self.generate_config.trigger_word is not None: + caption = caption.replace('[trigger]', self.generate_config.trigger_word) + + img: torch.Tensor = batch.tensor.clone() + image = self.to_pil(img) + + # image.save(output_depth_path) + if self.model_config.is_pixart: + pipe: PixArtSigmaPipeline = pipe + + # Encode the full image once + encoded_image = pipe.vae.encode( + pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype)) + if hasattr(encoded_image, "latent_dist"): + latents = encoded_image.latent_dist.sample(generator) + elif hasattr(encoded_image, "latents"): + latents = encoded_image.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + latents = pipe.vae.config.scaling_factor * latents + + # latents = self.sd.encode_images(img) + + # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps) + # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength) + # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0) + # timestep = timestep.to(device, dtype=torch.int32) + # latent = latent.to(device, dtype=self.torch_dtype) + # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype) + # latent = self.sd.add_noise(latent, noise, timestep) + # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:] + batch_size = 1 + num_images_per_prompt = 1 + + shape = (batch_size, pipe.transformer.config.in_channels, 
image.height // pipe.vae_scale_factor, + image.width // pipe.vae_scale_factor) + noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype) + + # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype) + num_inference_steps = self.generate_config.sample_steps + strength = self.generate_config.denoise_strength + # Get timesteps + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + pipe.scheduler.set_timesteps(num_inference_steps, device="cpu") + timesteps = pipe.scheduler.timesteps[t_start:] + timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents = pipe.scheduler.add_noise(latents, noise, timestep) + + gen_images = pipe.__call__( + prompt=caption, + negative_prompt=self.generate_config.neg, + latents=latents, + timesteps=timesteps, + width=image.width, + height=image.height, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + guidance_scale=self.generate_config.guidance_scale, + # strength=self.generate_config.denoise_strength, + use_resolution_binning=False, + output_type="np" + ).images[0] + gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8) + gen_images = Image.fromarray(gen_images) + else: + pipe: StableDiffusionXLImg2ImgPipeline = pipe + + gen_images = pipe.__call__( + prompt=caption, + negative_prompt=self.generate_config.neg, + image=image, + num_inference_steps=self.generate_config.sample_steps, + guidance_scale=self.generate_config.guidance_scale, + strength=self.generate_config.denoise_strength, + ).images[0] + os.makedirs(os.path.dirname(output_path), exist_ok=True) + gen_images.save(output_path) + + # save caption + with open(output_caption_path, 'w') as f: + f.write(caption) + + if output_inputs_path is not None: + os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True) + image.save(output_inputs_path) + with open(output_inputs_caption_path, 
'w') as f: + f.write(caption) + + pbar.update(1) + batch.cleanup() + + pbar.close() + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/ai-toolkit/extensions_built_in/advanced_generator/PureLoraGenerator.py b/ai-toolkit/extensions_built_in/advanced_generator/PureLoraGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..ec19da317230f5a837ebc6561fd4dc7cac2fd946 --- /dev/null +++ b/ai-toolkit/extensions_built_in/advanced_generator/PureLoraGenerator.py @@ -0,0 +1,102 @@ +import os +from collections import OrderedDict + +from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig +from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm +from toolkit.sd_device_states_presets import get_train_sd_device_state_preset +from toolkit.stable_diffusion_model import StableDiffusion +import gc +import torch +from jobs.process import BaseExtensionProcess +from toolkit.train_tools import get_torch_dtype + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class PureLoraGenerator(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.device = self.get_conf('device', 'cuda') + self.device_torch = torch.device(self.device) + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.generate_config = SampleConfig(**self.get_conf('sample', required=True)) + self.dtype = self.get_conf('dtype', 'float16') + self.torch_dtype = get_torch_dtype(self.dtype) + lorm_config = self.get_conf('lorm', None) + self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None + + self.device_state_preset = get_train_sd_device_state_preset( + device=torch.device(self.device), + ) + + self.progress_bar = None + self.sd = StableDiffusion( + device=self.device, + 
model_config=self.model_config, + dtype=self.dtype, + ) + + def run(self): + super().run() + print("Loading model...") + with torch.no_grad(): + self.sd.load_model() + self.sd.unet.eval() + self.sd.unet.to(self.device_torch) + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + te.to(self.device_torch) + else: + self.sd.text_encoder.eval() + self.sd.to(self.device_torch) + + print(f"Converting to LoRM UNet") + # replace the unet with LoRMUnet + convert_diffusers_unet_to_lorm( + self.sd.unet, + config=self.lorm_config, + ) + + sample_folder = os.path.join(self.output_folder) + gen_img_config_list = [] + + sample_config = self.generate_config + start_seed = sample_config.seed + current_seed = start_seed + for i in range(len(sample_config.prompts)): + if sample_config.walk_seed: + current_seed = start_seed + i + + filename = f"[time]_[count].{self.generate_config.ext}" + output_path = os.path.join(sample_folder, filename) + prompt = sample_config.prompts[i] + extra_args = {} + gen_img_config_list.append(GenerateImageConfig( + prompt=prompt, # it will autoparse the prompt + width=sample_config.width, + height=sample_config.height, + negative_prompt=sample_config.neg, + seed=current_seed, + guidance_scale=sample_config.guidance_scale, + guidance_rescale=sample_config.guidance_rescale, + num_inference_steps=sample_config.sample_steps, + network_multiplier=sample_config.network_multiplier, + output_path=output_path, + output_ext=sample_config.ext, + adapter_conditioning_scale=sample_config.adapter_conditioning_scale, + **extra_args + )) + + # send to be generated + self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler) + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/ai-toolkit/extensions_built_in/advanced_generator/ReferenceGenerator.py b/ai-toolkit/extensions_built_in/advanced_generator/ReferenceGenerator.py new file mode 100644 index 
0000000000000000000000000000000000000000..19e3b6e55786cde2ddcebfdc60230223dc443b62 --- /dev/null +++ b/ai-toolkit/extensions_built_in/advanced_generator/ReferenceGenerator.py @@ -0,0 +1,212 @@ +import os +import random +from collections import OrderedDict +from typing import List + +import numpy as np +from PIL import Image +from diffusers import T2IAdapter +from torch.utils.data import DataLoader +from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline +from tqdm import tqdm + +from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +from toolkit.sampler import get_sampler +from toolkit.stable_diffusion_model import StableDiffusion +import gc +import torch +from jobs.process import BaseExtensionProcess +from toolkit.data_loader import get_dataloader_from_datasets +from toolkit.train_tools import get_torch_dtype +from controlnet_aux.midas import MidasDetector +from diffusers.utils import load_image + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class GenerateConfig: + + def __init__(self, **kwargs): + self.prompts: List[str] + self.sampler = kwargs.get('sampler', 'ddpm') + self.neg = kwargs.get('neg', '') + self.seed = kwargs.get('seed', -1) + self.walk_seed = kwargs.get('walk_seed', False) + self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + self.prompt_2 = kwargs.get('prompt_2', None) + self.neg_2 = kwargs.get('neg_2', None) + self.prompts = kwargs.get('prompts', None) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.ext = kwargs.get('ext', 'png') + self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0) + if kwargs.get('shuffle', False): + # shuffle the prompts + random.shuffle(self.prompts) + + +class 
ReferenceGenerator(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.device = self.get_conf('device', 'cuda') + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.generate_config = GenerateConfig(**self.get_conf('generate', required=True)) + self.is_latents_cached = True + raw_datasets = self.get_conf('datasets', None) + if raw_datasets is not None and len(raw_datasets) > 0: + raw_datasets = preprocess_dataset_raw_config(raw_datasets) + self.datasets = None + self.datasets_reg = None + self.dtype = self.get_conf('dtype', 'float16') + self.torch_dtype = get_torch_dtype(self.dtype) + self.params = [] + if raw_datasets is not None and len(raw_datasets) > 0: + for raw_dataset in raw_datasets: + dataset = DatasetConfig(**raw_dataset) + is_caching = dataset.cache_latents or dataset.cache_latents_to_disk + if not is_caching: + self.is_latents_cached = False + if dataset.is_reg: + if self.datasets_reg is None: + self.datasets_reg = [] + self.datasets_reg.append(dataset) + else: + if self.datasets is None: + self.datasets = [] + self.datasets.append(dataset) + + self.progress_bar = None + self.sd = StableDiffusion( + device=self.device, + model_config=self.model_config, + dtype=self.dtype, + ) + print(f"Using device {self.device}") + self.data_loader: DataLoader = None + self.adapter: T2IAdapter = None + + def run(self): + super().run() + print("Loading model...") + self.sd.load_model() + device = torch.device(self.device) + + if self.generate_config.t2i_adapter_path is not None: + self.adapter = T2IAdapter.from_pretrained( + self.generate_config.t2i_adapter_path, + torch_dtype=self.torch_dtype, + varient="fp16" + ).to(device) + + midas_depth = MidasDetector.from_pretrained( + "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large" + ).to(device) + + if 
self.model_config.is_xl: + pipe = StableDiffusionXLAdapterPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder[0], + text_encoder_2=self.sd.text_encoder[1], + tokenizer=self.sd.tokenizer[0], + tokenizer_2=self.sd.tokenizer[1], + scheduler=get_sampler(self.generate_config.sampler), + adapter=self.adapter, + ).to(device, dtype=self.torch_dtype) + else: + pipe = StableDiffusionAdapterPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder, + tokenizer=self.sd.tokenizer, + scheduler=get_sampler(self.generate_config.sampler), + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + adapter=self.adapter, + ).to(device, dtype=self.torch_dtype) + pipe.set_progress_bar_config(disable=True) + + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True) + + self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd) + + num_batches = len(self.data_loader) + pbar = tqdm(total=num_batches, desc="Generating images") + seed = self.generate_config.seed + # load images from datasets, use tqdm + for i, batch in enumerate(self.data_loader): + batch: DataLoaderBatchDTO = batch + + file_item: FileItemDTO = batch.file_items[0] + img_path = file_item.path + img_filename = os.path.basename(img_path) + img_filename_no_ext = os.path.splitext(img_filename)[0] + output_path = os.path.join(self.output_folder, img_filename) + output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt') + output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png') + + caption = batch.get_caption_list()[0] + + img: torch.Tensor = batch.tensor.clone() + # image comes in -1 to 1. 
convert to a PIL RGB image + img = (img + 1) / 2 + img = img.clamp(0, 1) + img = img[0].permute(1, 2, 0).cpu().numpy() + img = (img * 255).astype(np.uint8) + image = Image.fromarray(img) + + width, height = image.size + min_res = min(width, height) + + if self.generate_config.walk_seed: + seed = seed + 1 + + if self.generate_config.seed == -1: + # random + seed = random.randint(0, 1000000) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # generate depth map + image = midas_depth( + image, + detect_resolution=min_res, # do 512 ? + image_resolution=min_res + ) + + # image.save(output_depth_path) + + gen_images = pipe( + prompt=caption, + negative_prompt=self.generate_config.neg, + image=image, + num_inference_steps=self.generate_config.sample_steps, + adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale, + guidance_scale=self.generate_config.guidance_scale, + ).images[0] + os.makedirs(os.path.dirname(output_path), exist_ok=True) + gen_images.save(output_path) + + # save caption + with open(output_caption_path, 'w') as f: + f.write(caption) + + pbar.update(1) + batch.cleanup() + + pbar.close() + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/ai-toolkit/extensions_built_in/advanced_generator/__init__.py b/ai-toolkit/extensions_built_in/advanced_generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65811655a991c84551d2915c7b04fa66c0d8fbaa --- /dev/null +++ b/ai-toolkit/extensions_built_in/advanced_generator/__init__.py @@ -0,0 +1,59 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. 
+from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class AdvancedReferenceGeneratorExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "reference_generator" + + # name is the name of the extension for printing + name = "Reference Generator" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ReferenceGenerator import ReferenceGenerator + return ReferenceGenerator + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class PureLoraGenerator(Extension): + # uid must be unique, it is how the extension is identified + uid = "pure_lora_generator" + + # name is the name of the extension for printing + name = "Pure LoRA Generator" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .PureLoraGenerator import PureLoraGenerator + return PureLoraGenerator + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class Img2ImgGeneratorExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "batch_img2img" + + # name is the name of the extension for printing + name = "Img2ImgGeneratorExtension" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .Img2ImgGenerator import Img2ImgGenerator + return Img2ImgGenerator + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + AdvancedReferenceGeneratorExtension, 
PureLoraGenerator, Img2ImgGeneratorExtension +] diff --git a/ai-toolkit/extensions_built_in/advanced_generator/config/train.example.yaml b/ai-toolkit/extensions_built_in/advanced_generator/config/train.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..793d5d55b9a282d58b53f0e6fafba0b5aa66f7af --- /dev/null +++ b/ai-toolkit/extensions_built_in/advanced_generator/config/train.example.yaml @@ -0,0 +1,91 @@ +--- +job: extension +config: + name: test_v1 + process: + - type: 'textual_inversion_trainer' + training_folder: "out/TI" + device: cuda:0 + # for tensorboard logging + log_dir: "out/.tensorboard" + embedding: + trigger: "your_trigger_here" + tokens: 12 + init_words: "man with short brown hair" + save_format: "safetensors" # 'safetensors' or 'pt' + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 5 # only affects step counts + datasets: + - folder_path: "/path/to/dataset" + caption_ext: "txt" + default_caption: "[trigger]" + buckets: true + resolution: 512 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 3000 + weight_jitter: 0.0 + lr: 5e-5 + train_unet: false + gradient_checkpointing: true + train_text_encoder: false + optimizer: "adamw" +# optimizer: "prodigy" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 4 + dtype: bf16 + xformers: true + min_snr_gamma: 5.0 +# skip_first_sample: true + noise_offset: 0.0 # not needed for this + model: + # objective reality v2 + name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of [trigger] laughing" + - "photo of [trigger] smiling" + - 
"[trigger] close up" + - "dark scene [trigger] frozen" + - "[trigger] nighttime" + - "a painting of [trigger]" + - "a drawing of [trigger]" + - "a cartoon of [trigger]" + - "[trigger] pixar style" + - "[trigger] costume" + neg: "" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/ai-toolkit/extensions_built_in/audio_models/__init__.py b/ai-toolkit/extensions_built_in/audio_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b01e2655877273b22e2d5164c0233dcbe3a46ccb --- /dev/null +++ b/ai-toolkit/extensions_built_in/audio_models/__init__.py @@ -0,0 +1,7 @@ +from .ace_step import AceStep15Model, AceStep15XLModel + +AI_TOOLKIT_MODELS = [ + # put a list of models here + AceStep15Model, + AceStep15XLModel, +] diff --git a/ai-toolkit/extensions_built_in/audio_models/ace_step/__init__.py b/ai-toolkit/extensions_built_in/audio_models/ace_step/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b24a377057ddcbd573f5d5c0dd245160bfc78c6 --- /dev/null +++ b/ai-toolkit/extensions_built_in/audio_models/ace_step/__init__.py @@ -0,0 +1 @@ +from .ace_step_15_model import AceStep15Model, AceStep15XLModel \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/audio_models/ace_step/ace_step_15_model.py 
b/ai-toolkit/extensions_built_in/audio_models/ace_step/ace_step_15_model.py new file mode 100644 index 0000000000000000000000000000000000000000..430f82ba700491d08f46920b41c088b94a11bf7b --- /dev/null +++ b/ai-toolkit/extensions_built_in/audio_models/ace_step/ace_step_15_model.py @@ -0,0 +1,335 @@ +import json +import os +from typing import List, Optional +import huggingface_hub +import torch +from safetensors.torch import load_file, save_file +from extensions_built_in.audio_models.base_audio_model import BaseAudioModel +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig +from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.util.quantize import get_qtype, quantize, quantize_model + +from optimum.quanto import freeze +from .src.model import ( + AceStep15, + OobleckVAE, + TextEncoder, + get_silence_latent, + load_models, +) +from transformers import AutoTokenizer +from .src.pipeline import AceStep15Pipeline + +scheduler_config = { + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": False, +} + +def to_number(str_or_number, default): + if isinstance(str_or_number, (int, float)): + return str_or_number + if str_or_number is None: + return default + if str_or_number == "": + return default + try: + return float(str_or_number) + except ValueError: + try: + return int(str_or_number) + except ValueError as e: + raise ValueError(f"Could not convert {str_or_number} to a number") from e + + +def parse_ace_step_caption(text): + """Parse a tagged caption file back into a dict.""" + import re + + def tag(name): + m = re.search(rf"<{name}>(.*?)", text, re.DOTALL) + return m.group(1).strip() if m else "" + + return { + "caption": tag("CAPTION"), + "lyrics": tag("LYRICS"), + "bpm": to_number(tag("BPM"), 120), + "keyscale": tag("KEYSCALE"), + "timesignature": tag("TIMESIGNATURE"), + "duration": 
to_number(tag("DURATION"), 1.0), + "language": tag("LANGUAGE"), + } + + +class AceStep15Model(BaseAudioModel): + arch = "ace_step_15" + sample_rate = 48000 + + def __init__( + self, + device, + model_config, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + # self.target_lora_modules = ['AceStep15'] + self.target_lora_modules = ["DiTModel"] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def load_model(self): + dtype = self.torch_dtype + device = self.device_torch + + model_path = self.model_config.name_or_path + + if not os.path.exists(model_path): + # assume it is a hf repo like org/repo/filename.safetensors + path_parts = model_path.split("/") + if len(path_parts) != 3: + raise ValueError( + f"Model path {model_path} does not exist and is not a valid Hugging Face repo path" + ) + model_path = huggingface_hub.hf_hub_download( + repo_id=f"{path_parts[0]}/{path_parts[1]}", + filename=path_parts[2], + ) + # load the models from the single safetensors file + load_device = device + if self.model_config.low_vram: + load_device = "cpu" + + models = load_models(model_path, device=load_device, dtype=dtype) + + self.model = models["model"] + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + # quantize_model(self, self.model.decoder) + quantize(self.model, weights=get_qtype(self.model_config.qtype)) + freeze(self.model) + flush() + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + self.model.to("cpu") + + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + raise NotImplementedError("Layer offloading not yet implemented for 
AceStep15Model") + + self.text_encoder = models["text_encoder"] + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Text Encoder") + quantize(self.text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(self.text_encoder) + flush() + + self.vae = models["vae"] + + # move back to device + self.model.to(device) + self.text_encoder.to(device) + self.vae.to(device) + self.tokenizer = models["tokenizer"] + + self.pipeline = AceStep15Pipeline( + transformer=self.model, + vae=self.vae, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=self.get_train_scheduler(), + ) + if self.model_config.low_vram: + self.pipeline.do_tiled_decoding = True + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if isinstance(prompt, str): + prompts = [prompt] + else: + prompts = prompt + + if self.text_encoder.device == torch.device("cpu"): + self.text_encoder.to(self.device_torch) + # we need the encoder from the model + if self.model.encoder.device == torch.device("cpu"): + self.model.encoder.to(self.device_torch) + + # the prompt should be json as a string. Try to parse it. + json_prompts = [] + for p in prompts: + try: + json_prompts.append(parse_ace_step_caption(p)) + except json.JSONDecodeError: + raise ValueError( + f"Prompt {p} is not a valid JSON string. 
Prompts must be JSON for this model" + ) + + if self.pipeline.text_encoder.device == torch.device("cpu"): + self.pipeline.text_encoder.to(self.device_torch) + + device = self.text_encoder.device + dtype = self.text_encoder.dtype + + batch_pe = None + # TODO not sure this will allow for proper batching + + for json_prompt in json_prompts: + prompt = json_prompt.get("caption", "") + lyrics = json_prompt.get("lyrics", "") + bpm = json_prompt.get("bpm", 120) + key = json_prompt.get("key", "C") + time_sig = json_prompt.get("time_sig", "4/4") + duration = json_prompt.get("duration", 10) + duration = int(duration) if isinstance(duration, (int, float)) else 10 + language = json_prompt.get("language", "en") + + text_embeddings, text_mask, lyric_embeddings, lyric_mask = ( + self.pipeline.get_text_embedings( + prompt, lyrics, bpm, key, time_sig, duration, language + ) + ) + latent_len = int(duration * self.pipeline.LATENT_RATE) + # Silence as source latent [1, 64, T] -> [1, T, 64] for DiT + sil = get_silence_latent(latent_len, device, dtype) # [1, 64, T] + src = sil.transpose(1, 2) # [1, T, 64] + chunk_masks = torch.ones_like(src) + + # Reference audio (silence) + ref = sil[:, :, :750].transpose(1, 2) # [1, 750, 64] + ref_order = torch.zeros(1, device=device, dtype=torch.long) + enc_h, enc_m, _ = self.pipeline.transformer.prepare_condition( + text_embeddings, + text_mask, + lyric_embeddings, + lyric_mask, + ref, + ref_order, + src, + chunk_masks, + ) + + pe = PromptEmbeds(enc_h, attention_mask=enc_m) + if batch_pe is None: + batch_pe = pe + else: + batch_pe = concat_prompt_embeds(batch_pe, pe) + return batch_pe + + def get_transformer_block_names(self) -> Optional[List[str]]: + return ["layers"] + + def get_generation_pipeline(self): + return self.pipeline + + def generate_single_audio( + self, + pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if 
self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + # make sure gen config is setup for audio + if gen_config.output_ext not in ['mp3', 'wav']: + gen_config.output_ext = 'mp3' + prompt = gen_config.prompt + json_prompt = parse_ace_step_caption(prompt) + prompt = json_prompt.get("caption", "") + lyrics = json_prompt.get("lyrics", "") + bpm = json_prompt.get("bpm", 120) + key = json_prompt.get("key", "C") + time_sig = json_prompt.get("time_sig", "4/4") + duration = json_prompt.get("duration", 0) + language = json_prompt.get("language", "en") + + output = self.pipeline( + prompt=None, # we are passing in the embeds directly, so no need for a prompt + encoder_embeddings=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.torch_dtype), + encoder_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.bool), + num_inference_steps=gen_config.num_inference_steps, + duration=duration, + generator=generator, + bpm=bpm, + key=key, + time_sig=time_sig, + language=language, + guidance_scale=gen_config.guidance_scale, + ) + return output + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, #(1, 300, 64) + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs, + ): + if self.model.decoder.device == torch.device("cpu"): + self.model.decoder.to(self.device_torch) + with torch.no_grad(): + model: AceStep15 = self.model + tt = timestep.to(self.device_torch, dtype=torch.long) / 1000 + latent_len = latent_model_input.shape[1] + device = self.device_torch + dtype = self.torch_dtype + attn = torch.ones(1, latent_len, device=device, dtype=dtype) + + # build context from silence latent matching the actual input length + sil = get_silence_latent(latent_len, device, dtype) # [1, 64, T] + src = sil.transpose(1, 2) # [1, T, 64] + chunk_masks = torch.ones_like(src) + context = torch.cat([src, chunk_masks], dim=-1) # [1, T, 128] + + pred = model.decoder( + x=latent_model_input.detach(), + 
timestep=tt.detach(), + timestep_r=tt.detach(), + attention_mask=attn.detach(), + enc_h=text_embeddings.text_embeds.to(self.device_torch, dtype=self.torch_dtype).detach(), + enc_m=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.bool).detach(), + context=context.detach(), + ) + return pred + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get("noise") + batch = kwargs.get("batch") + return (noise - batch.latents).detach() + + def encode_audio(self, audio_tensor: torch.Tensor, device=None, dtype=None): + if device is None: + device = self.device_torch + if dtype is None: + dtype = self.torch_dtype + if self.vae.device == torch.device("cpu"): + self.vae.to(device) + output = self.vae.encode(audio_tensor.to(device=device, dtype=dtype)) + # transpose from [B, 64, T] to [B, T, 64] for DiT + output = output.transpose(1, 2).contiguous() + return output + + +class AceStep15XLModel(AceStep15Model): + arch = "ace_step_15_xl" diff --git a/ai-toolkit/extensions_built_in/audio_models/ace_step/src/__init__.py b/ai-toolkit/extensions_built_in/audio_models/ace_step/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ai-toolkit/extensions_built_in/audio_models/ace_step/src/model.py b/ai-toolkit/extensions_built_in/audio_models/ace_step/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8bf9ffaaea9729f784a3c6f94b7df8607e5ca920 --- /dev/null +++ b/ai-toolkit/extensions_built_in/audio_models/ace_step/src/model.py @@ -0,0 +1,1570 @@ +#!/usr/bin/env python3 +""" +ACE-Step v1.5 — Standalone single-file inference. + +Generates music from text + lyrics. All model code inlined — no project imports, +no trust_remote_code. Uses ComfyUI-style architecture for AIO checkpoint compat. 
def compute_timesteps(num_steps, shift=3.0):
    """Build the shifted flow-matching sigma schedule.

    ``num_steps`` sigmas are spaced evenly from 1.0 towards 0.0 (the final
    0.0 is left out) and warped with ``shift * s / (1 + (shift - 1) * s)``.

    Args:
        num_steps: Number of sampling steps.
        shift: Schedule shift factor; 1.0 leaves the linear spacing as is.

    Returns:
        A list of ``num_steps`` floats in descending order, starting at 1.0.
    """
    import numpy as np

    grid = np.linspace(1.0, 0.0, num_steps + 1)
    linear = grid[:-1]  # drop the terminal 0
    numerator = shift * linear
    denominator = 1 + (shift - 1) * linear
    return (numerator / denominator).tolist()
+ -0.5658, + 0.6266, + 0.6996, + -0.1365, + -0.1291, + -0.0776, + -0.1171, + -0.2743, + -0.8422, + -0.1168, + 1.5539, + -4.6936, + 0.7436, + -1.1846, + -0.2637, + 0.6933, + -6.7266, + 0.0966, + -0.1187, + -0.3501, + -1.1736, + 0.0587, + -2.0517, + -1.3651, + 0.7508, + -0.2490, + -1.3548, + -0.1290, + -0.7261, + 1.1132, + -0.3249, + 0.2337, + 0.3004, + 0.6605, + -0.0298, + -0.1989, + -0.4041, + 0.2843, + -1.0963, + -0.5519, + 0.2639, + -1.0436, + -0.1183, + 0.0640, + 0.4460, + -1.1001, + -0.6172, + -1.3241, + 1.1379, + 0.5623, + -0.1507, + -0.1963, + -0.4742, + -2.4697, + 0.5302, + 0.5381, + 0.4636, + -0.1782, + -0.0687, + 1.0333, + 0.4202, + ], + [ + 0.3040, + -0.1367, + 0.6200, + 0.0665, + -0.0642, + 0.4655, + -0.1187, + -0.0440, + 0.2941, + -0.2753, + 0.0173, + -0.2421, + -0.0147, + 1.5603, + -2.7025, + 0.7907, + -0.9736, + -0.0682, + 0.1294, + -5.0707, + -0.2167, + 0.3302, + -0.1513, + -0.8100, + -0.3894, + -0.2884, + -0.3149, + 0.8660, + -0.3817, + -1.7061, + 0.5824, + -0.4840, + 0.6938, + 0.1859, + 0.1753, + 0.3081, + 0.0195, + 0.1403, + -0.0754, + -0.2091, + 0.1251, + -0.1578, + -0.4968, + -0.1052, + -0.4554, + -0.0320, + 0.1284, + 0.4974, + -1.1889, + -0.0344, + -0.8313, + 0.2953, + 0.5445, + -0.6249, + -0.1595, + -0.0682, + -3.1412, + 0.0484, + 0.4153, + 0.8260, + -0.1526, + -0.0625, + 0.5366, + 0.8473, + ], + [ + 5.3524e-02, + -1.7534e-01, + 5.4443e-01, + -4.3501e-01, + -2.1317e-03, + 3.7200e-01, + -4.0143e-03, + -1.5516e-01, + -1.2968e-01, + -1.5375e-01, + -7.7107e-02, + -2.0593e-01, + -3.2780e-01, + 1.5142e00, + -2.6101e00, + 5.8698e-01, + -1.2716e00, + -2.4773e-01, + -2.7933e-02, + -5.0799e00, + 1.1601e-01, + 4.0987e-01, + -2.2030e-02, + -6.6495e-01, + -2.0995e-01, + -6.3474e-01, + -1.5893e-01, + 8.2745e-01, + -2.2992e-01, + -1.6816e00, + 5.4440e-01, + -4.9579e-01, + 5.5128e-01, + 3.0477e-01, + 8.3052e-02, + -6.1782e-02, + 5.9036e-03, + 2.9553e-01, + -8.0645e-02, + -1.0060e-01, + 1.9144e-01, + -3.8124e-01, + -7.2949e-01, + 2.4520e-02, + -5.0814e-01, + 
2.3977e-01, + 9.2943e-02, + 3.9256e-01, + -1.1993e00, + -3.2752e-01, + -7.2707e-01, + 2.9476e-01, + 4.3542e-01, + -8.8597e-01, + -4.1686e-01, + -8.5390e-02, + -2.9018e00, + 6.4988e-02, + 5.3945e-01, + 9.1988e-01, + 5.8762e-02, + -7.0098e-02, + 6.4772e-01, + 8.9118e-01, + ], + [ + -3.2225e-02, + -1.3195e-01, + 5.6411e-01, + -5.4766e-01, + -5.2170e-03, + 3.1425e-01, + -5.4367e-02, + -1.9419e-01, + -1.3059e-01, + -1.3660e-01, + -9.0984e-02, + -1.9540e-01, + -2.5590e-01, + 1.5440e00, + -2.6349e00, + 6.8273e-01, + -1.2532e00, + -1.9810e-01, + -2.2793e-02, + -5.0506e00, + 1.8818e-01, + 5.0109e-01, + 7.3546e-03, + -6.8771e-01, + -3.0676e-01, + -7.3257e-01, + -1.6687e-01, + 9.2232e-01, + -1.8987e-01, + -1.7267e00, + 5.3355e-01, + -5.3179e-01, + 4.4953e-01, + 2.8820e-01, + 1.3012e-01, + -2.0943e-01, + -1.1348e-01, + 3.3929e-01, + -1.5069e-01, + -1.2919e-01, + 1.8929e-01, + -3.6166e-01, + -8.0756e-01, + 6.6387e-02, + -5.8867e-01, + 1.6978e-01, + 1.0134e-01, + 3.3877e-01, + -1.2133e00, + -3.2492e-01, + -8.1237e-01, + 3.8101e-01, + 4.3765e-01, + -8.0596e-01, + -4.4531e-01, + -4.7513e-02, + -2.9266e00, + 1.1741e-03, + 4.5123e-01, + 9.3075e-01, + 5.3688e-02, + -1.9621e-01, + 6.4530e-01, + 9.3870e-01, + ], + ] + ], + device=device, + ).movedim(-1, 1) + body = ( + torch.tensor( + [ + [ + [ + -1.3672e-01, + -1.5820e-01, + 5.8594e-01, + -5.7422e-01, + 3.0273e-02, + 2.7930e-01, + -2.5940e-03, + -2.0703e-01, + -1.6113e-01, + -1.4746e-01, + -2.7710e-02, + -1.8066e-01, + -2.9688e-01, + 1.6016e00, + -2.6719e00, + 7.7734e-01, + -1.3516e00, + -1.9434e-01, + -7.1289e-02, + -5.0938e00, + 2.4316e-01, + 4.7266e-01, + 4.6387e-02, + -6.6406e-01, + -2.1973e-01, + -6.7578e-01, + -1.5723e-01, + 9.5312e-01, + -2.0020e-01, + -1.7109e00, + 5.8984e-01, + -5.7422e-01, + 5.1562e-01, + 2.8320e-01, + 1.4551e-01, + -1.8750e-01, + -5.9814e-02, + 3.6719e-01, + -1.0059e-01, + -1.5723e-01, + 2.0605e-01, + -4.3359e-01, + -8.2812e-01, + 4.5654e-02, + -6.6016e-01, + 1.4844e-01, + 9.4727e-02, + 3.8477e-01, + 
class RotaryEmbedding(nn.Module):
    """Rotary position embedding tables (cos/sin) with a small cache.

    Args:
        dim: Per-head feature size; tables have ``dim`` columns.
        base: Frequency base used for the inverse-frequency spectrum.
    """

    def __init__(self, dim, base=1000000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        # Non-persistent: derived from (dim, base), so kept out of state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cos = None
        self._sin = None
        self._cached_len = 0

    def _build_cache(self, seq_len, device, dtype):
        """(Re)build the tables unless the cache already covers the request."""
        cache_valid = (
            self._cos is not None
            and seq_len <= self._cached_len
            and self._cos.device == device
            # The tables are stored already cast to `dtype`; without this
            # check a later call with another dtype got stale-dtype tables.
            and self._cos.dtype == dtype
        )
        if cache_valid:
            return
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.to(device))
        emb = torch.cat((freqs, freqs), dim=-1)
        self._cos = emb.cos().to(dtype)
        self._sin = emb.sin().to(dtype)
        self._cached_len = seq_len

    def forward(self, x, seq_len):
        """Return ``(cos, sin)`` tables of shape ``[seq_len, dim]``.

        Args:
            x: Tensor whose device and dtype the tables must match.
            seq_len: Number of positions required.
        """
        self._build_cache(seq_len, x.device, x.dtype)
        return self._cos[:seq_len], self._sin[:seq_len]
def timestep_embedding(t, dim, scale=1000, max_period=10000):
    """Sinusoidal embedding of a batch of timesteps.

    Args:
        t: 1-D tensor of timesteps, shape ``[N]``.
        dim: Width of the returned embedding.
        scale: Multiplier applied to ``t`` before embedding.
        max_period: Controls the lowest frequency of the spectrum.

    Returns:
        Float tensor of shape ``[N, dim]`` laid out as ``[cos | sin]``; when
        ``dim`` is odd a trailing zero column pads the width to ``dim``.
    """
    t = t * scale
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(half, dtype=torch.float32, device=t.device)
        / half
    )
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # cos/sin only fill 2 * (dim // 2) columns; pad so the result really
        # has `dim` features, as consumers of the embedding expect.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
    return emb
is_cross=False, + sliding_window=None, + ): + super().__init__() + self.num_heads = num_heads + self.num_kv = num_kv + self.head_dim = head_dim + self.is_cross = is_cross + self.sliding_window = sliding_window + self.q_proj = nn.Linear(hidden, num_heads * head_dim, bias=False) + self.k_proj = nn.Linear(hidden, num_kv * head_dim, bias=False) + self.v_proj = nn.Linear(hidden, num_kv * head_dim, bias=False) + self.o_proj = nn.Linear(num_heads * head_dim, hidden, bias=False) + self.q_norm = RMSNorm(head_dim, eps) + self.k_norm = RMSNorm(head_dim, eps) + + def forward(self, x, encoder_hidden_states=None, position_embeddings=None): + B, L, _ = x.shape + q = self.q_norm( + self.q_proj(x).view(B, L, self.num_heads, self.head_dim) + ).transpose(1, 2) + + src = ( + encoder_hidden_states + if (self.is_cross and encoder_hidden_states is not None) + else x + ) + sL = src.shape[1] + k = self.k_norm( + self.k_proj(src).view(B, sL, self.num_kv, self.head_dim) + ).transpose(1, 2) + v = self.v_proj(src).view(B, sL, self.num_kv, self.head_dim).transpose(1, 2) + + if position_embeddings is not None and not ( + self.is_cross and encoder_hidden_states is not None + ): + q, k = apply_rotary(q, k, *position_embeddings) + + n_rep = self.num_heads // self.num_kv + if n_rep > 1: + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + + attn_bias = None + if self.sliding_window is not None and not self.is_cross: + idx = torch.arange(L, device=q.device) + in_win = ( + torch.abs(idx.unsqueeze(1) - idx.unsqueeze(0)) <= self.sliding_window + ) + attn_bias = torch.zeros(L, sL, device=q.device, dtype=q.dtype) + attn_bias.masked_fill_(~in_win, torch.finfo(q.dtype).min) + attn_bias = attn_bias.unsqueeze(0).unsqueeze(0) + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) + return self.o_proj(out.transpose(1, 2).reshape(B, L, -1)) + + +class EncoderLayer(nn.Module): + def __init__(self, hidden, heads, kv, head_dim, inter, eps=1e-6): + super().__init__() + 
class DiTLayer(nn.Module):
    """One DiT block: modulated self-attention, cross-attention, modulated MLP.

    Six modulation vectors are obtained by adding the learned
    ``scale_shift_table`` to the timestep projection; they supply the
    shift/scale/gate of the self-attention branch and of the MLP branch.
    The cross-attention branch is a plain pre-norm residual.
    """

    def __init__(
        self, hidden, heads, kv, head_dim, inter, eps=1e-6, sliding_window=None
    ):
        super().__init__()
        # Registration order is kept as-is (it defines parameter order).
        self.self_attn_norm = RMSNorm(hidden, eps)
        self.self_attn = Attention(
            hidden, heads, kv, head_dim, eps=eps, sliding_window=sliding_window
        )
        self.cross_attn_norm = RMSNorm(hidden, eps)
        self.cross_attn = Attention(hidden, heads, kv, head_dim, eps=eps, is_cross=True)
        self.mlp_norm = RMSNorm(hidden, eps)
        self.mlp = MLP(hidden, inter)
        self.scale_shift_table = nn.Parameter(torch.empty(1, 6, hidden))

    def forward(self, x, temb, enc, position_embeddings):
        """Apply the block.

        Args:
            x: Hidden states.
            temb: Timestep projection that is added to ``scale_shift_table``
                and split into six chunks along dim 1.
            enc: Encoder hidden states used by cross-attention.
            position_embeddings: ``(cos, sin)`` rotary tables for
                self-attention.
        """
        modulation = self.scale_shift_table.to(temb) + temb
        (
            shift_attn,
            scale_attn,
            gate_attn,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = modulation.chunk(6, dim=1)

        # Self-attention on the modulated, normalised input; gated residual.
        attn_in = self.self_attn_norm(x) * (1 + scale_attn) + shift_attn
        attn_out = self.self_attn(attn_in, position_embeddings=position_embeddings)
        x = x + attn_out * gate_attn

        # Cross-attention over the conditioning sequence.
        x = x + self.cross_attn(self.cross_attn_norm(x), encoder_hidden_states=enc)

        # Feed-forward on the modulated, normalised input; gated residual.
        mlp_in = self.mlp_norm(x) * (1 + scale_mlp) + shift_mlp
        x = x + self.mlp(mlp_in) * gate_mlp
        return x
class TimbreEncoder(nn.Module):
    """Encodes packed reference-audio latents into per-sample timbre tokens.

    Every packed reference sequence is run through the encoder stack and
    reduced to its first position; these vectors are then regrouped per
    batch sample according to ``order_mask``.
    """

    def __init__(
        self, timbre_dim, hidden, n_layers, heads, kv, head_dim, inter, eps=1e-6
    ):
        super().__init__()
        self.embed_tokens = nn.Linear(timbre_dim, hidden)
        self.norm = RMSNorm(hidden, eps)
        self.rotary_emb = RotaryEmbedding(head_dim)
        self.layers = nn.ModuleList(
            [
                EncoderLayer(hidden, heads, kv, head_dim, inter, eps)
                for _ in range(n_layers)
            ]
        )
        # Registered as a parameter; not read by forward() in this class.
        self.special_token = nn.Parameter(torch.empty(1, 1, hidden))

    def forward(self, packed, order_mask):
        """Encode references and scatter them back to their batch samples.

        Args:
            packed: Packed reference latents, one sequence per row.
            order_mask: 1-D integer tensor; entry ``i`` is the batch index
                that packed row ``i`` belongs to.

        Returns:
            ``(result, mask)``: ``result`` is ``[B, max_refs, D]`` zero-padded
            timbre tokens and ``mask`` is a ``[B, max_refs]`` long tensor with
            1 on filled slots, where ``B = order_mask.max() + 1``.
        """
        x = self.embed_tokens(packed)
        cos, sin = self.rotary_emb(x, x.shape[1])
        for layer in self.layers:
            x = layer(x, (cos, sin))
        x = self.norm(x)
        cls = x[:, 0, :]
        # Unpack to batch
        N, D = cls.shape
        B = int(order_mask.max().item() + 1)
        counts = torch.bincount(order_mask, minlength=B)
        mc = int(counts.max().item())
        result = torch.zeros(B, mc, D, device=cls.device, dtype=cls.dtype)
        mask = torch.zeros(B, mc, device=cls.device, dtype=torch.long)

        # Slot of row i = number of earlier rows with the same batch index.
        # Computed on the host from one transfer, instead of a per-row
        # `.item()` plus prefix `.sum().item()` (quadratic, 2N device syncs).
        owners = order_mask[:N].tolist()
        filled = [0] * B
        slots = []
        for owner in owners:
            slots.append(filled[owner])
            filled[owner] += 1

        batch_idx = order_mask[:N].to(device=cls.device, dtype=torch.long)
        slot_idx = torch.tensor(slots, device=cls.device, dtype=torch.long)
        # (batch, slot) pairs are unique, so a single indexed write is safe.
        result[batch_idx, slot_idx] = cls
        mask[batch_idx, slot_idx] = 1
        return result, mask
class DiTModel(nn.Module):
    """Diffusion transformer decoder operating on patched 1-D latents.

    The input latents are concatenated with a context tensor on the feature
    axis, patched with a strided Conv1d, passed through a stack of
    ``DiTLayer`` blocks conditioned on timestep and encoder states, and
    un-patched with a transposed convolution.

    Args:
        in_ch: Feature width of ``context`` and ``x`` concatenated.
        hidden: Transformer width.
        n_layers: Number of ``DiTLayer`` blocks.
        heads, kv, head_dim, inter: Attention / MLP sizes passed to each layer.
        patch: Patch size (kernel and stride of the in/out convolutions).
        out_ch: Output feature width.
        layer_types: Optional per-layer list of ``"sliding_attention"`` /
            ``"full_attention"``; defaults to alternating, starting with
            sliding attention on layer 0.
        sliding_window: Window passed to the sliding-attention layers.
        eps: RMSNorm epsilon.
        cond_dim: Width of the incoming encoder states (defaults to ``hidden``).
    """

    def __init__(
        self,
        in_ch,
        hidden,
        n_layers,
        heads,
        kv,
        head_dim,
        inter,
        patch,
        out_ch,
        layer_types=None,
        sliding_window=128,
        eps=1e-6,
        cond_dim=None,
    ):
        super().__init__()
        self.patch_size = patch
        self.rotary_emb = RotaryEmbedding(head_dim)
        # The leading nn.Identity puts the conv at index 1 of the Sequential.
        # NOTE(review): presumably so parameter names line up with the
        # checkpoint's keys - confirm against the checkpoint.
        self.proj_in = nn.Sequential(
            nn.Identity(), nn.Conv1d(in_ch, hidden, kernel_size=patch, stride=patch)
        )
        self.time_embed = TimestepEmbed(hidden)
        self.time_embed_r = TimestepEmbed(hidden)
        # Projects encoder states from cond_dim to the decoder width.
        self.condition_embedder = nn.Linear(cond_dim or hidden, hidden)
        # Default pattern: even layers use sliding attention, odd layers full.
        lt = layer_types or [
            "sliding_attention" if i % 2 == 0 else "full_attention"
            for i in range(n_layers)
        ]
        self.layers = nn.ModuleList(
            [
                DiTLayer(
                    hidden,
                    heads,
                    kv,
                    head_dim,
                    inter,
                    eps,
                    sliding_window=sliding_window
                    if lt[i] == "sliding_attention"
                    else None,
                )
                for i in range(n_layers)
            ]
        )
        self.norm_out = RMSNorm(hidden, eps)
        self.proj_out = nn.Sequential(
            nn.Identity(),
            nn.ConvTranspose1d(hidden, out_ch, kernel_size=patch, stride=patch),
        )
        # Allocated uninitialised with torch.empty; nothing in this class
        # assigns its values.
        self.scale_shift_table = nn.Parameter(torch.empty(1, 2, hidden))
        self.gradient_checkpointing = False

    @property
    def device(self):
        # Device of the first registered parameter.
        return next(self.parameters()).device

    @property
    def dtype(self):
        # Dtype of the first registered parameter.
        return next(self.parameters()).dtype

    def forward(self, x, timestep, timestep_r, attention_mask, enc_h, enc_m, context):
        """Predict from noisy latents.

        Args:
            x: Noisy latents, ``[B, T, C]``.
            timestep: Timestep tensor fed to ``time_embed``.
            timestep_r: Second timestep; ``timestep - timestep_r`` is fed to
                ``time_embed_r``.
            attention_mask: Accepted but not read anywhere in this method.
            enc_h: Encoder hidden states for cross-attention.
            enc_m: Encoder mask; accepted but not read anywhere in this
                method. NOTE(review): padded encoder positions are therefore
                not masked in cross-attention - confirm this is intended.
            context: Context features concatenated in front of ``x`` on the
                last axis.

        Returns:
            Tensor of shape ``[B, T, out_ch]`` (trimmed back to the input
            length if padding was added for patching).
        """
        temb_t, proj_t = self.time_embed(timestep, dtype=x.dtype)
        temb_r, proj_r = self.time_embed_r(timestep - timestep_r, dtype=x.dtype)
        temb = temb_t + temb_r
        tproj = proj_t + proj_r

        h = torch.cat([context, x], dim=-1)
        orig_len = h.shape[1]
        # Right-pad the sequence axis up to a multiple of the patch size.
        if h.shape[1] % self.patch_size != 0:
            h = F.pad(h, (0, 0, 0, self.patch_size - h.shape[1] % self.patch_size))
        # Conv1d expects [B, C, T]; transpose in and back out.
        h = self.proj_in(h.transpose(1, 2)).transpose(1, 2)
        enc = self.condition_embedder(enc_h)
        cos, sin = self.rotary_emb(h, h.shape[1])
        for layer in self.layers:
            # Checkpoint only when gradients are being recorded.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                h = ckpt.checkpoint(
                    layer, h, tproj, enc, (cos, sin), use_reentrant=False
                )
            else:
                h = layer(h, tproj, enc, (cos, sin))
        # Final modulation from the summed timestep embeddings.
        shift, scale = (self.scale_shift_table.to(temb) + temb.unsqueeze(1)).chunk(
            2, dim=1
        )
        h = self.norm_out(h) * (1 + scale) + shift
        h = self.proj_out(h.transpose(1, 2)).transpose(1, 2)
        # Drop any frames that only exist because of the patch padding.
        return h[:, :orig_len, :]
class SnakeBeta(nn.Module):
    """Snake-style periodic activation with per-channel log-scale parameters.

    Computes ``x + sin(a * x)**2 / (b + 1e-9)`` with ``a = exp(alpha)`` and
    ``b = exp(beta)``, broadcast over inputs laid out as
    ``[batch, channels, time]``.
    """

    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # Reshape [C] -> [1, C, 1] so the parameters broadcast per channel.
        freq = self.alpha[None, :, None].exp().to(x.device)
        magnitude = self.beta[None, :, None].exp().to(x.device)
        inv_magnitude = 1.0 / (magnitude + 1e-9)
        periodic = torch.sin(x * freq).pow(2)
        return x + inv_magnitude * periodic
class OobleckVAE(nn.Module):
    """Convolutional audio autoencoder (encoder, pass-through bottleneck, decoder).

    Args:
        in_ch: Number of audio channels.
        channels: Base channel width; multiplied by ``c_mults`` per stage.
        latent_dim: Latent channels; the encoder emits ``2 * latent_dim``.
        c_mults: Channel multiplier for each stage.
        strides: Down/up-sampling stride for each stage; their product is the
            total audio-samples-per-latent-frame factor.
    """

    def __init__(
        self,
        in_ch=2,
        channels=128,
        latent_dim=64,
        c_mults=(1, 2, 4, 8, 16),
        strides=(2, 4, 4, 6, 10),
    ):
        super().__init__()
        # Prepend 1 so stage i maps cm[i] -> cm[i + 1] channel multiples.
        cm = [1] + list(c_mults)
        # Encoder
        enc = [WNConv1d(in_ch, cm[0] * channels, 7, padding=3)]
        for i in range(len(cm) - 1):
            enc.append(EncBlock(cm[i] * channels, cm[i + 1] * channels, strides[i]))
        enc += [
            SnakeBeta(cm[-1] * channels),
            WNConv1d(cm[-1] * channels, latent_dim * 2, 3, padding=1),
        ]
        self.encoder = _SeqWrap(*enc)
        # Decoder: mirrors the encoder, walking the stages in reverse.
        dec = [WNConv1d(latent_dim, cm[-1] * channels, 7, padding=3)]
        for i in range(len(cm) - 1, 0, -1):
            dec.append(DecBlock(cm[i] * channels, cm[i - 1] * channels, strides[i - 1]))
        dec += [
            SnakeBeta(cm[0] * channels),
            WNConv1d(cm[0] * channels, in_ch, 7, padding=3, bias=False),
        ]
        self.decoder = _SeqWrap(*dec)
        self.bottleneck = VAEBottleneck()
        # Audio samples produced per latent frame.
        self.upscale_factor = math.prod(strides)

    def encode(self, x):
        """Encode audio ``x`` and return the bottleneck's encoding of it."""
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        """Decode latents ``x`` back to audio in a single pass."""
        return self.decoder(self.bottleneck.decode(x))

    def tiled_decode(self, x, tile_seconds=10.0, overlap_seconds=1.0):
        """VRAM-light decode: split the latent into ~tile_seconds tiles with
        overlap_seconds of overlap, decode each tile independently, and
        linearly crossfade the overlapping audio regions.

        Args:
            x: Latents with time on the last axis.
            tile_seconds: Tile length; converted to frames via LATENT_RATE.
            overlap_seconds: Overlap between consecutive tiles (at least one
                latent frame is always used).

        Returns:
            Decoded audio; identical to a plain decoder call when the latent
            fits inside a single tile.

        Raises:
            ValueError: If the overlap is not smaller than the tile.
        """
        z = self.bottleneck.decode(x)
        tile_frames = max(1, round(tile_seconds * LATENT_RATE))
        overlap_frames = max(1, round(overlap_seconds * LATENT_RATE))
        if overlap_frames >= tile_frames:
            raise ValueError("overlap_seconds must be smaller than tile_seconds")

        T = z.shape[-1]
        if T <= tile_frames:
            return self.decoder(z)

        # Consecutive tiles start `step` frames apart, so they share
        # `overlap_frames` latent frames (= fade_len audio samples).
        step = tile_frames - overlap_frames
        fade_len = overlap_frames * self.upscale_factor
        out_T = T * self.upscale_factor

        out = None
        ramp = None
        write_pos = 0  # number of audio samples written so far

        for i, start in enumerate(range(0, T, step)):
            end = min(start + tile_frames, T)
            decoded = self.decoder(z[..., start:end])

            if out is None:
                # Allocate lazily so batch/channel/device/dtype follow the
                # decoder's actual output.
                out = decoded.new_zeros(decoded.shape[0], decoded.shape[1], out_T)
                ramp = torch.linspace(0, 1, fade_len, device=decoded.device, dtype=decoded.dtype)

            if i == 0:
                n = decoded.shape[-1]
                out[..., :n] = decoded
                write_pos = n
            else:
                # Crossfade the last fade_len samples already written with
                # the first fade_len samples of this tile...
                blend_start = write_pos - fade_len
                out[..., blend_start:blend_start + fade_len] = (
                    out[..., blend_start:blend_start + fade_len] * (1 - ramp)
                    + decoded[..., :fade_len] * ramp
                )
                # ...then append the rest of the tile unchanged.
                tail = decoded.shape[-1] - fade_len
                out[..., write_pos:write_pos + tail] = decoded[..., fade_len:]
                write_pos += tail

            # Stop once a tile reaches the end; later range() starts would
            # only re-decode frames that are already covered.
            if end == T:
                break

        return out[..., :write_pos]

    @property
    def device(self):
        # Device of the first registered parameter.
        return next(self.parameters()).device

    @property
    def dtype(self):
        # Dtype of the first registered parameter.
        return next(self.parameters()).dtype
Forward returns last_hidden_state.""" + + def __init__(self, qwen_model): + super().__init__() + self.model = qwen_model # the inner model (layers, norm, embed_tokens) + + def encode_text(self, input_ids): + return self.model(input_ids=input_ids).last_hidden_state + + def encode_lyrics(self, input_ids): + return self.model.embed_tokens(input_ids) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Loading +# ═══════════════════════════════════════════════════════════════════════════════ + + +def infer_dit_config(dit_sd): + """Infer model config from DiT state dict tensor shapes.""" + # hidden_size from decoder norm + hidden = dit_sd["decoder.scale_shift_table"].shape[2] + # intermediate_size from MLP gate_proj + inter = dit_sd["decoder.layers.0.mlp.gate_proj.weight"].shape[0] + # num_heads from q_proj: q_proj.weight is [num_heads * head_dim, hidden] + q_size = dit_sd["decoder.layers.0.self_attn.q_proj.weight"].shape[0] + # head_dim from q_norm + head_dim = dit_sd["decoder.layers.0.self_attn.q_norm.weight"].shape[0] + heads = q_size // head_dim + # num_kv_heads from k_proj + k_size = dit_sd["decoder.layers.0.self_attn.k_proj.weight"].shape[0] + kv = k_size // head_dim + # num_dit_layers: count unique layer indices + n_dit = ( + max(int(k.split(".")[2]) for k in dit_sd if k.startswith("decoder.layers.")) + 1 + ) + # encoder hidden (may differ from decoder hidden for XL models) + enc_hidden = dit_sd["encoder.text_projector.weight"].shape[0] + # encoder layers + n_lyric = ( + max( + int(k.split(".")[3]) + for k in dit_sd + if k.startswith("encoder.lyric_encoder.layers.") + ) + + 1 + ) + n_timbre = ( + max( + int(k.split(".")[3]) + for k in dit_sd + if k.startswith("encoder.timbre_encoder.layers.") + ) + + 1 + ) + # encoder attention config + enc_heads = ( + 
dit_sd["encoder.lyric_encoder.layers.0.self_attn.q_proj.weight"].shape[0] + // head_dim + ) + enc_kv = ( + dit_sd["encoder.lyric_encoder.layers.0.self_attn.k_proj.weight"].shape[0] + // head_dim + ) + enc_inter = dit_sd["encoder.lyric_encoder.layers.0.mlp.gate_proj.weight"].shape[0] + config = dict( + hidden=hidden, + inter=inter, + heads=heads, + kv=kv, + head_dim=head_dim, + n_dit=n_dit, + n_lyric=n_lyric, + n_timbre=n_timbre, + enc_hidden=enc_hidden, + enc_heads=enc_heads, + enc_kv=enc_kv, + enc_inter=enc_inter, + ) + print( + f" Detected config: hidden={hidden}, inter={inter}, heads={heads}, kv={kv}, " + f"n_dit={n_dit}, enc_hidden={enc_hidden}" + ) + return config + + +def load_models(checkpoint_path, device="cuda", dtype=torch.bfloat16): + if not os.path.isfile(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + print(f"Loading from: {checkpoint_path}") + + sd = load_file(checkpoint_path) + + # --- DiT --- + print(" Loading DiT...") + dit_sd = { + k.removeprefix("model.diffusion_model."): v + for k, v in sd.items() + if k.startswith("model.diffusion_model.") + } + cfg = infer_dit_config(dit_sd) + model = AceStep15( + hidden=cfg["hidden"], + inter=cfg["inter"], + heads=cfg["heads"], + kv=cfg["kv"], + head_dim=cfg["head_dim"], + n_dit=cfg["n_dit"], + n_lyric=cfg["n_lyric"], + n_timbre=cfg["n_timbre"], + enc_hidden=cfg["enc_hidden"], + enc_heads=cfg["enc_heads"], + enc_kv=cfg["enc_kv"], + enc_inter=cfg["enc_inter"], + ) + missing, unexpected = model.load_state_dict(dit_sd, strict=False) + # tokenizer/detokenizer keys are expected to be unused (cover mode only) + unexpected = [ + k for k in unexpected if not k.startswith(("tokenizer.", "detokenizer.")) + ] + if missing: + print(f" DiT missing: {len(missing)} (first 3: {missing[:3]})") + if unexpected: + print(f" DiT unexpected: {len(unexpected)} (first 3: {unexpected[:3]})") + model = model.to(device).to(dtype).eval() + + # --- VAE --- + print(" Loading VAE...") + vae_sd = 
{k.removeprefix("vae."): v for k, v in sd.items() if k.startswith("vae.")} + vae = OobleckVAE() + m, u = vae.load_state_dict(vae_sd, strict=False) + if m: + print(f" VAE missing: {len(m)} (first 3: {m[:3]})") + if u: + print(f" VAE unexpected: {len(u)}") + vae = vae.to(device).to(dtype).eval() + + # --- Text encoder (Qwen3-Embedding from AIO) --- + print(" Loading text encoder...") + te_sd = { + k.removeprefix("text_encoders.qwen3_06b.transformer.model."): v + for k, v in sd.items() + if k.startswith("text_encoders.qwen3_06b.transformer.model.") + } + # Load Qwen3 model structure from transformers, then override weights + from transformers import Qwen3Model, Qwen3Config + + qwen_cfg = Qwen3Config( + vocab_size=151669, + hidden_size=1024, + intermediate_size=3072, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=8, + head_dim=128, + max_position_embeddings=32768, + rms_norm_eps=1e-6, + ) + qwen = Qwen3Model(qwen_cfg) + m2, u2 = qwen.load_state_dict(te_sd, strict=False) + if m2: + print(f" TE missing: {len(m2)} (first 3: {m2[:3]})") + te = TextEncoder(qwen).to(device).to(dtype).eval() + + # Tokenizer — download from HF + print(" Loading tokenizer...") + tok = AutoTokenizer.from_pretrained( + "Qwen/Qwen3-Embedding-0.6B", trust_remote_code=False + ) + + del sd # free memory + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + print(" Done.\n") + return dict( + model=model, vae=vae, text_encoder=te, tokenizer=tok, device=device, dtype=dtype + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Inference +# ═══════════════════════════════════════════════════════════════════════════════ + + +@torch.inference_mode() +def get_latent(audio_path, models): + """Encode audio file to VAE latent. 
Returns [1, 64, T] tensor.""" + vae, device, dtype = models["vae"], models["device"], models["dtype"] + wav, sr = torchaudio.load(audio_path) + if sr != SAMPLE_RATE: + wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE) + if wav.shape[0] == 1: + wav = wav.repeat(2, 1) + elif wav.shape[0] > 2: + wav = wav[:2] + return vae.encode(wav.unsqueeze(0).to(device, dtype)) # [1, 64, T] + + +@torch.inference_mode() +def generate( + models, + prompt, + lyrics="", + duration=30.0, + seed=42, + bpm="N/A", + key="N/A", + time_sig="N/A", + language="en", + timesteps=None, + guidance_scale=1.0, +): + model = models["model"] + vae = models["vae"] + te = models["text_encoder"] + tok = models["tokenizer"] + device = models["device"] + dtype = models["dtype"] + + t_sched = timesteps + latent_len = int(duration * LATENT_RATE) + print( + f"Duration: {duration}s -> {latent_len} latent frames, {len(t_sched)} steps" + + (f", CFG={guidance_scale}" if guidance_scale > 1.0 else "") + ) + + # Silence as source latent [1, 64, T] -> [1, T, 64] for DiT + sil = get_silence_latent(latent_len, device, dtype) # [1, 64, T] + src = sil.transpose(1, 2) # [1, T, 64] + chunk_masks = torch.ones_like(src) + + # Text encoding + metas = f"- bpm: {bpm}\n- timesignature: {time_sig}\n- keyscale: {key}\n- duration: {int(duration)} seconds\n" + caption = SFT_PROMPT.format( + instruction="Fill the audio semantic mask based on the given conditions:", + caption=prompt, + metas=metas, + ) + lyrics_text = f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>" + + cap_tok = tok(caption, truncation=True, max_length=256, return_tensors="pt") + lyr_tok = tok(lyrics_text, truncation=True, max_length=2048, return_tensors="pt") + + text_h = te.encode_text(cap_tok.input_ids.to(device)).to(dtype) + text_m = cap_tok.attention_mask.to(device).bool() + lyric_h = te.encode_lyrics(lyr_tok.input_ids.to(device)).to(dtype) + lyric_m = lyr_tok.attention_mask.to(device).bool() + + # Reference audio (silence) + ref = sil[:, :, 
:750].transpose(1, 2) # [1, 750, 64] + ref_order = torch.zeros(1, device=device, dtype=torch.long) + + # Prepare conditions (conditional) + print("Preparing conditions...") + enc_h, enc_m, ctx = model.prepare_condition( + text_h, text_m, lyric_h, lyric_m, ref, ref_order, src, chunk_masks + ) + + # Prepare unconditional conditions for CFG + use_cfg = guidance_scale > 1.0 + enc_h_uncond = None + if use_cfg: + enc_h_uncond = model.null_condition_emb.expand_as(enc_h) + + # Noise + gen = torch.Generator(device=device).manual_seed(seed) + noise_ch = ctx.shape[-1] // 2 + xt = torch.randn(1, latent_len, noise_ch, generator=gen, device=device, dtype=dtype) + + # Diffusion + print("Running diffusion...") + t0 = time.time() + t_sched_t = torch.tensor(t_sched, device=device, dtype=dtype) + attn = torch.ones(1, latent_len, device=device, dtype=dtype) + + for i in range(len(t_sched_t)): + tv = t_sched_t[i].item() + tt = torch.full((1,), tv, device=device, dtype=dtype) + + vt_cond = model.decoder(xt, tt, tt, attn, enc_h, enc_m, ctx) + + if use_cfg: + vt_uncond = model.decoder(xt, tt, tt, attn, enc_h_uncond, enc_m, ctx) + vt = vt_uncond + guidance_scale * (vt_cond - vt_uncond) + else: + vt = vt_cond + + if i == len(t_sched_t) - 1: + xt = xt - vt * tv + else: + xt = xt - vt * (tv - t_sched_t[i + 1].item()) + + print(f"Diffusion: {time.time() - t0:.2f}s") + + # VAE decode + print("Decoding audio...") + t0 = time.time() + wav = vae.decode(xt.transpose(1, 2)) # [1, 2, samples] + wav = wav[0, :, : int(duration * SAMPLE_RATE)] + print(f"VAE decode: {time.time() - t0:.2f}s") + return wav.cpu().float() + + +# ═══════════════════════════════════════════════════════════════════════════════ +# CLI +# ═══════════════════════════════════════════════════════════════════════════════ + + +def main(): + p = argparse.ArgumentParser(description="ACE-Step v1.5 standalone inference") + p.add_argument("--prompt", required=True) + p.add_argument("--lyrics", default="") + p.add_argument("--duration", 
type=float, default=30.0) + p.add_argument("--output", default="output.wav") + p.add_argument("--seed", type=int, default=42) + p.add_argument( + "--model", + default="base", + choices=["base", "turbo"], + help="Model variant (default: base)", + ) + p.add_argument( + "--checkpoint", default=None, help="Override path to AIO .safetensors" + ) + p.add_argument("--device", default=None) + p.add_argument( + "--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"] + ) + p.add_argument("--bpm", default="N/A") + p.add_argument("--key", default="N/A") + p.add_argument("--time-sig", default="N/A") + p.add_argument("--language", default="en") + p.add_argument( + "--steps", + type=int, + default=None, + help="Diffusion steps (default: 30 for base, 8 for turbo)", + ) + p.add_argument( + "--shift", type=float, default=3.0, help="Timestep shift (default: 3.0)" + ) + p.add_argument( + "--cfg", + type=float, + default=None, + help="CFG guidance scale (default: 3.5 for base, 1.0 for turbo)", + ) + args = p.parse_args() + + device = args.device or ( + "cuda" + if torch.cuda.is_available() + else "mps" + if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() + else "cpu" + ) + dtype = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + }[args.dtype] + if device == "mps": + dtype = torch.float32 + + lyrics = args.lyrics + if lyrics.startswith("@") and os.path.isfile(lyrics[1:]): + lyrics = open(lyrics[1:]).read() + else: + lyrics = lyrics.replace("\\n", "\n") + + # Model-specific defaults + is_turbo = args.model == "turbo" + ckpt = args.checkpoint or MODEL_PATHS[args.model] + steps = args.steps or (8 if is_turbo else 30) + cfg = args.cfg if args.cfg is not None else (1.0 if is_turbo else 3.5) + + # Timestep schedule + if is_turbo and steps == 8: + ts = TURBO_TIMESTEPS.get(args.shift, TURBO_TIMESTEPS[3.0]) + else: + ts = compute_timesteps(steps, args.shift) + + print( + f"ACE-Step v1.5 ({args.model}) | 
{device} ({dtype}) | seed={args.seed} | {args.duration}s | {steps} steps | CFG={cfg}" + ) + models = load_models(ckpt, device, dtype) + wav = generate( + models, + args.prompt, + lyrics, + args.duration, + args.seed, + args.bpm, + args.key, + args.time_sig, + args.language, + ts, + cfg, + ) + + os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True) + torchaudio.save(args.output, wav, SAMPLE_RATE) + print(f"Saved: {args.output} ({wav.shape[1] / SAMPLE_RATE:.1f}s stereo)") + + +if __name__ == "__main__": + main() diff --git a/ai-toolkit/extensions_built_in/audio_models/ace_step/src/pipeline.py b/ai-toolkit/extensions_built_in/audio_models/ace_step/src/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..57edff6404514985d2ab8c525de29b6ee9e48abc --- /dev/null +++ b/ai-toolkit/extensions_built_in/audio_models/ace_step/src/pipeline.py @@ -0,0 +1,167 @@ +from typing import List, Optional + +import torch +import time +import os +from .model import ( + SAMPLE_RATE, + AceStep15, + OobleckVAE, + TextEncoder, + get_silence_latent, + compute_timesteps, +) +from diffusers.utils.torch_utils import randn_tensor +from transformers import AutoTokenizer + +SFT_PROMPT = """# Instruction +{instruction} + +# Caption +{caption} + +# Metas +{metas}<|endoftext|> +""" + + +class AceStep15Pipeline: + SAMPLE_RATE = 48000 + LATENT_RATE = 25 # 48000 / 1920 + SFT_PROMPT = SFT_PROMPT + + def __init__(self, transformer, vae, text_encoder, tokenizer, scheduler): + self.transformer: AceStep15 = transformer + self.vae: OobleckVAE = vae + self.text_encoder: TextEncoder = text_encoder + self.tokenizer: AutoTokenizer = tokenizer + self.scheduler = scheduler + self.do_tiled_decoding = False + + def to(self, *args, **kwargs): + self.transformer.to(*args, **kwargs) + self.vae.to(*args, **kwargs) + self.text_encoder.to(*args, **kwargs) + + def get_text_embedings( + self, prompt, lyrics, bpm, key, time_sig, duration, language + ): + metas = f"- bpm: {bpm}\n- 
timesignature: {time_sig}\n- keyscale: {key}\n- duration: {int(duration)} seconds\n" + caption = self.SFT_PROMPT.format( + instruction="Fill the audio semantic mask based on the given conditions:", + caption=prompt, + metas=metas, + ) + lyrics_text = f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>" + + cap_tok = self.tokenizer( + caption, truncation=True, max_length=256, return_tensors="pt" + ) + lyr_tok = self.tokenizer( + lyrics_text, truncation=True, max_length=2048, return_tensors="pt" + ) + + text_embeddings = self.text_encoder.encode_text( + cap_tok.input_ids.to(self.text_encoder.device) + ).to(self.transformer.dtype) + text_mask = cap_tok.attention_mask.to(self.text_encoder.device).bool() + lyric_embeddings = self.text_encoder.encode_lyrics( + lyr_tok.input_ids.to(self.text_encoder.device) + ).to(self.transformer.dtype) + lyric_mask = lyr_tok.attention_mask.to(self.text_encoder.device).bool() + + return text_embeddings, text_mask, lyric_embeddings, lyric_mask + + def __call__( + self, + prompt="", + lyrics="", + encoder_embeddings: Optional[List[torch.Tensor]] = None, + encoder_mask: Optional[List[torch.Tensor]] = None, + # uses a null conditional for unconditional if not provided, which is what we want for CFG + num_inference_steps=50, + duration=30.0, + generator: torch.Generator = None, + bpm="N/A", + key="N/A", + time_sig="N/A", + language="en", + guidance_scale=1.0, + ): + t_sched = compute_timesteps(num_inference_steps, 3.0) + latent_len = int(duration * self.LATENT_RATE) + device = self.transformer.device + dtype = self.transformer.dtype + + # Text encoding + if encoder_embeddings is not None and encoder_mask is not None: + enc_h = encoder_embeddings + enc_m = encoder_mask + sil = get_silence_latent(latent_len, device, dtype) # [1, 64, T] + src = sil.transpose(1, 2) # [1, T, 64] + chunk_masks = torch.ones_like(src) + ctx = torch.cat([src, chunk_masks.to(src.dtype)], dim=-1) + else: + text_h, text_m, lyric_h, lyric_m = 
self.get_text_embedings( + prompt, lyrics, bpm, key, time_sig, duration, language + ) + + # Silence as source latent [1, 64, T] -> [1, T, 64] for DiT + sil = get_silence_latent(latent_len, device, dtype) # [1, 64, T] + src = sil.transpose(1, 2) # [1, T, 64] + chunk_masks = torch.ones_like(src) + + # Reference audio (silence) + ref = sil[:, :, :750].transpose(1, 2) # [1, 750, 64] + ref_order = torch.zeros(1, device=device, dtype=torch.long) + + # Prepare conditions (conditional) + enc_h, enc_m, ctx = self.transformer.prepare_condition( + text_h, text_m, lyric_h, lyric_m, ref, ref_order, src, chunk_masks + ) + + # Prepare unconditional conditions for CFG + use_cfg = guidance_scale > 1.0 + enc_h_uncond = None + if use_cfg: + enc_h_uncond = self.transformer.null_condition_emb.expand_as(enc_h) + + # Noise + if generator is None: + generator = torch.Generator(device=device) + noise_ch = ctx.shape[-1] // 2 + xt = randn_tensor( + (1, latent_len, noise_ch), generator=generator, device=device, dtype=dtype + ) + # xt = torch.randn(1, latent_len, noise_ch, generator=generator, device=device, dtype=dtype) + + # Diffusion + t_sched_t = torch.tensor(t_sched, device=device, dtype=dtype) + attn = torch.ones(1, latent_len, device=device, dtype=dtype) + + for i in range(len(t_sched_t)): + tv = t_sched_t[i].item() + tt = torch.full((1,), tv, device=device, dtype=dtype) + + vt_cond = self.transformer.decoder(xt, tt, tt, attn, enc_h, enc_m, ctx) + + if use_cfg: + vt_uncond = self.transformer.decoder( + xt, tt, tt, attn, enc_h_uncond, enc_m, ctx + ) + vt = vt_uncond + guidance_scale * (vt_cond - vt_uncond) + else: + vt = vt_cond + + if i == len(t_sched_t) - 1: + xt = xt - vt * tv + else: + xt = xt - vt * (tv - t_sched_t[i + 1].item()) + + # VAE decode + if self.do_tiled_decoding: + wav = self.vae.tiled_decode(xt.transpose(1, 2)) # [1, 2, samples] + else: + wav = self.vae.decode(xt.transpose(1, 2)) # [1, 2, samples] + wav = wav[0, :, : int(duration * SAMPLE_RATE)] + return wav diff --git 
a/ai-toolkit/extensions_built_in/audio_models/base_audio_model.py b/ai-toolkit/extensions_built_in/audio_models/base_audio_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6860a1aaf4466414da60ac1f4904bfc27f07013b --- /dev/null +++ b/ai-toolkit/extensions_built_in/audio_models/base_audio_model.py @@ -0,0 +1,99 @@ +import json + +import torch + +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.models.base_model import BaseModel +from toolkit.prompt_utils import PromptEmbeds + + +class BaseAudioModel(BaseModel): + sample_rate = 48000 + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_audio_model = True + + def generate_single_image( + self, + pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # This is called on the base model. We override it to make it make more sense for audio models. + return self.generate_single_audio( + pipeline, + gen_config, + conditional_embeds, + unconditional_embeds, + generator, + extra, + ) + + def generate_single_audio( + self, + pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # This is called on the base model. We override it to make it make more sense for audio models. 
+ raise NotImplementedError( + "generate_single_audio is not implemented for this model" + ) + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + # we need to save the model, vae, text encoder, and tokenizer together since they are all trained together and depend on each other + raise NotImplementedError( + "save_model is not implemented for this model. Use the pipeline directly instead." + ) + + def convert_lora_weights_before_save(self, state_dict): + # currently starte with transformer. but needs to start with diffusion_model. for comfyui + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + # saved as diffusion_model. but needs to be transformer. for ai-toolkit + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + + def encode_images(self, image_list: torch.Tensor, device=None, dtype=None): + # make it more obvious for audio models + return self.encode_audio(image_list, device=device, dtype=dtype) + + def encode_audio(self, audio_tensor: torch.Tensor, device=None, dtype=None): + if device is None: + device = self.device_torch + if dtype is None: + dtype = self.torch_dtype + if self.vae.device == torch.device("cpu"): + self.vae.to(device) + return self.vae.encode(audio_tensor.to(device=device, dtype=dtype)) diff --git a/ai-toolkit/extensions_built_in/captioner/AceStepCaptioner.py b/ai-toolkit/extensions_built_in/captioner/AceStepCaptioner.py new file mode 100644 index 0000000000000000000000000000000000000000..ce89a2f9cf0c1b509a5dd28305f8f94230506f24 --- /dev/null +++ b/ai-toolkit/extensions_built_in/captioner/AceStepCaptioner.py @@ -0,0 +1,261 @@ +from typing import Optional + +import librosa +import 
numpy as np +import torch +import torchaudio +from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor +from collections import OrderedDict + +from optimum.quanto import freeze +from toolkit.basic import flush +from toolkit.util.quantize import quantize, get_qtype + +from .BaseCaptioner import BaseCaptioner, CaptionConfig +import transformers +import logging +import warnings + +# transformers.logging.set_verbosity_error() +warnings.filterwarnings("ignore") +logging.disable(logging.WARNING) + +TARGET_SAMPLE_RATE = 16000 +CAPTIONER_ID = "ACE-Step/acestep-captioner" +TRANSCRIBER_ID = "ACE-Step/acestep-transcriber" + +# Key profiles for Krumhansl-Schmuckler key detection +MAJOR_PROFILE = np.array( + [6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88] +) +MINOR_PROFILE = np.array( + [6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17] +) +KEY_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Audio analysis (BPM, key, time signature) via librosa +# ═══════════════════════════════════════════════════════════════════════════════ + + +def analyze_audio(audio_path): + """Extract BPM, key, and time signature from audio using librosa.""" + y, sr = librosa.load(audio_path, sr=22050, mono=True) + duration = librosa.get_duration(y=y, sr=sr) + + # BPM + tempo, _ = librosa.beat.beat_track(y=y, sr=sr) + if hasattr(tempo, "__len__"): + tempo = tempo[0] + bpm = int(round(float(tempo))) + + # Key detection via chroma correlation with key profiles + chroma = librosa.feature.chroma_cqt(y=y, sr=sr) + chroma_avg = chroma.mean(axis=1) + major_corrs = np.array( + [np.corrcoef(np.roll(MAJOR_PROFILE, i), chroma_avg)[0, 1] for i in range(12)] + ) + minor_corrs = np.array( + [np.corrcoef(np.roll(MINOR_PROFILE, i), chroma_avg)[0, 1] for i in range(12)] + ) + + best_major_idx = major_corrs.argmax() + best_minor_idx = 
minor_corrs.argmax() + if major_corrs[best_major_idx] >= minor_corrs[best_minor_idx]: + keyscale = f"{KEY_NAMES[best_major_idx]} major" + else: + keyscale = f"{KEY_NAMES[best_minor_idx]} minor" + + # Time signature estimation from beat strength pattern + onset_env = librosa.onset.onset_strength(y=y, sr=sr) + tempo_est, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr) + if len(beats) >= 8: + beat_strengths = onset_env[beats] + # Check 3/4 vs 4/4 by looking at periodicity of strong beats + acf = np.correlate( + beat_strengths - beat_strengths.mean(), + beat_strengths - beat_strengths.mean(), + mode="full", + ) + acf = acf[len(acf) // 2 :] + if len(acf) > 6: + # Look at autocorrelation peaks at lag 3 vs lag 4 + score_3 = acf[3] if len(acf) > 3 else 0 + score_4 = acf[4] if len(acf) > 4 else 0 + timesig = "3" if score_3 > score_4 * 1.2 else "4" + else: + timesig = "4" + else: + timesig = "4" + + return { + "bpm": bpm, + "keyscale": keyscale, + "timesignature": timesig, + "duration": int(round(duration)), + } + + +class AceStepCaptionConfig(CaptionConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.fixed_caption: Optional[str] = kwargs.get("fixed_caption", None) + + +class AceStepCaptioner(BaseCaptioner): + caption_config_class = AceStepCaptionConfig + caption_config: AceStepCaptionConfig + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super(AceStepCaptioner, self).__init__(process_id, job, config, **kwargs) + + def load_model(self): + self.print_and_status_update("Loading transcriber model") + self.model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + self.caption_config.model_name_or_path, + dtype=self.torch_dtype, + device_map="cpu", + ) + self.model.to(self.device_torch) + self.model.disable_talker() + if self.caption_config.quantize: + self.print_and_status_update("Quantizing transcriber model") + quantize(self.model, weights=get_qtype(self.caption_config.qtype)) + freeze(self.model) + 
flush() + self.processor = Qwen2_5OmniProcessor.from_pretrained( + self.caption_config.model_name_or_path + ) + if self.caption_config.low_vram: + self.model.to("cpu") + + self.model2 = None + self.processor2 = None + + if self.caption_config.fixed_caption is not None: + # load captioner model + self.print_and_status_update("Loading captioner model") + self.model2 = Qwen2_5OmniForConditionalGeneration.from_pretrained( + self.caption_config.model_name_or_path2, + dtype=self.torch_dtype, + device_map="cpu", + ) + self.model2.to(self.device_torch) + self.model2.disable_talker() + if self.caption_config.quantize: + self.print_and_status_update("Quantizing captioner model") + quantize(self.model2, weights=get_qtype(self.caption_config.qtype)) + freeze(self.model2) + flush() + self.processor2 = Qwen2_5OmniProcessor.from_pretrained( + self.caption_config.model_name_or_path2, + ) + + if self.caption_config.low_vram: + self.model2.to("cpu") + flush() + + def run_qwen_audio(self, model, processor, audio_data, sr, prompt_text): + """Run a Qwen2.5-Omni model on audio with a text prompt.""" + conversation = [ + { + "role": "user", + "content": [ + {"type": "audio", "audio": "<|audio_bos|><|AUDIO|><|audio_eos|>"}, + {"type": "text", "text": prompt_text}, + ], + } + ] + text = processor.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=False + ) + inputs = processor( + text=text, + audio=[audio_data], + images=None, + videos=None, + return_tensors="pt", + padding=True, + sampling_rate=sr, + ) + inputs = inputs.to(model.device).to(model.dtype) + text_ids = model.generate(**inputs, return_audio=False) + output = processor.batch_decode( + text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + result = output[0] + marker = "assistant\n" + if marker in result: + result = result[result.rfind(marker) + len(marker) :] + return result.strip() + + def get_audio_lyrics(self, audio_data: torch.Tensor) -> str: + if self.caption_config.low_vram and 
self.model2.device != torch.device("cpu"): + # move captioner to cpu + self.model2.to("cpu") + # move lyric model if needed + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + prompt_text = "*Task* Transcribe this audio in detail" + return self.run_qwen_audio( + self.model, self.processor, audio_data, TARGET_SAMPLE_RATE, prompt_text + ) + + def get_audio_caption(self, audio_data: torch.Tensor) -> str: + if self.caption_config.low_vram and self.model.device != torch.device("cpu"): + # move lyricmodel to cpu + self.model.to("cpu") + # move captioner model if needed + if self.model2.device == torch.device("cpu"): + self.model2.to(self.device_torch) + prompt_text = "*Task* Describe this music in detail. Include genre, mood, instrumentation, tempo feel, and vocal style if present." + return self.run_qwen_audio( + self.model2, self.processor2, audio_data, TARGET_SAMPLE_RATE, prompt_text + ) + + def get_caption_for_file(self, file_path: str) -> str: + try: + # analyze audio with librosa + analysis = analyze_audio(file_path) + + # load audio with torchaudio for transcription + waveform, sr = torchaudio.load(file_path) + waveform = waveform.to(self.device_torch) + if waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + if sr != TARGET_SAMPLE_RATE: + waveform = torchaudio.functional.resample( + waveform, sr, TARGET_SAMPLE_RATE + ) + audio_data = waveform.squeeze(0).cpu().numpy() + + # get the lyrics from the audio + lyrics = self.get_audio_lyrics(audio_data) + + language = "en" + + if "# Languages" in lyrics and "# Lyrics" in lyrics: + language = lyrics.split("# Languages")[1].split("# Lyrics")[0] + # remove newlines and extra spaces from language + language = language.replace("\n", "").strip() + lyrics = lyrics.split("# Lyrics")[1].strip() + + # get the caption from the audio + if self.caption_config.fixed_caption is not None: + caption = self.caption_config.fixed_caption + else: + caption = 
self.get_audio_caption(audio_data) + + output = f"\n{caption}\n\n" + output += f"\n{lyrics}\n\n" + output += f"{analysis['bpm']}\n" + output += f"{analysis['keyscale']}\n" + output += f"{analysis['timesignature']}\n" + output += f"{analysis['duration']}\n" + output += f"{language}" + return output + except Exception as e: + print(f"Error processing {file_path}: {e}") + return None diff --git a/ai-toolkit/extensions_built_in/captioner/BaseCaptioner.py b/ai-toolkit/extensions_built_in/captioner/BaseCaptioner.py new file mode 100644 index 0000000000000000000000000000000000000000..cb63002b0b2d4113bc5f76c61ef9e91613e278cf --- /dev/null +++ b/ai-toolkit/extensions_built_in/captioner/BaseCaptioner.py @@ -0,0 +1,439 @@ +import asyncio +from collections import OrderedDict + +import sqlite3 +import os +from typing import Literal, Optional +import threading +import time +import signal +import concurrent.futures +from PIL import Image + +import torch +from jobs.process import BaseExtensionProcess +import tqdm + +from toolkit.train_tools import get_torch_dtype + +AITK_Status = Literal["running", "stopped", "error", "completed"] + + +class CaptionConfig: + def __init__(self, **kwargs): + self.model_name_or_path = kwargs.get("model_name_or_path", None) + if self.model_name_or_path is None: + raise ValueError("model_name_or_path is required in config") + self.model_name_or_path2 = kwargs.get("model_name_or_path2", None) + self.extensions = kwargs.get("extensions", []) + if self.extensions is None or len(self.extensions) == 0: + raise ValueError("At least one extension is required in config") + self.path_to_caption = kwargs.get("path_to_caption", None) + if self.path_to_caption is None: + raise ValueError("path_to_caption is required in config") + self.dtype = kwargs.get("dtype", "bf16") + self.device = kwargs.get("device", "cuda") + self.quantize = kwargs.get("quantize", False) + self.qtype = kwargs.get("qtype", "float8") + self.low_vram = kwargs.get("low_vram", False) + 
self.caption_extension = kwargs.get("caption_extension", "txt") + self.recaption = kwargs.get("recaption", False) + self.max_res = kwargs.get("max_res", 512) + self.max_new_tokens = kwargs.get("max_new_tokens", 128) + self.caption_prompt = kwargs.get( + "caption_prompt", "Describe this image in detail." + ) + self.compile = kwargs.get("compile", False) + + +class BaseCaptioner(BaseExtensionProcess): + caption_config_class = CaptionConfig + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super(BaseCaptioner, self).__init__(process_id, job, config, **kwargs) + self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db") + self.job_id = os.environ.get("AITK_JOB_ID", None) + self.job_id = self.job_id.strip() if self.job_id is not None else None + self.is_ui_captioner = True + if not os.path.exists(self.sqlite_db_path): + self.is_ui_captioner = False + else: + print(f"Using SQLite database at {self.sqlite_db_path}") + if self.job_id is None: + self.is_ui_captioner = False + else: + print(f'Job ID: "{self.job_id}"') + + self.is_stopping = False + + if self.is_ui_captioner: + self.is_stopping = False + # Create a thread pool for database operations + self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # Track all async tasks + self._async_tasks = [] + # Initialize the status + self._run_async_operation(self._update_status("running", "Starting")) + self._stop_watcher_started = False + # self.start_stop_watcher(interval_sec=2.0) + + self.caption_config = self.caption_config_class(**self.get_conf("caption", {})) + self.model = None + self.processor = None + self.model2 = None + self.processor2 = None + self.file_paths = [] + self.step_num = 0 + self.device_torch = torch.device(self.caption_config.device) + self.torch_dtype = get_torch_dtype(self.caption_config.dtype) + + def run(self): + super(BaseCaptioner, self).run() + with torch.no_grad(): + self.start_stop_watcher() + self.update_status("running", "Loading 
def load_pil_image(self, file_path: str, max_res: Optional[int] = None) -> "Image.Image":
    """Open an image as RGB and optionally shrink it to a pixel budget.

    Args:
        file_path: path of the image to open.
        max_res: when given, the image is scaled down (aspect ratio kept) so
            that its area does not exceed ``max_res * max_res`` pixels. Images
            already within the budget are returned at their original size.

    Returns:
        The RGB image, resized with bicubic resampling when it was too large.
    """
    # Convert inside the context manager so the pixel data is decoded into a
    # new image and the underlying file handle is released right away instead
    # of staying open until garbage collection.
    with Image.open(file_path) as source:
        image = source.convert("RGB")

    if max_res is not None:
        max_pixels = max_res * max_res
        image_pixels = image.width * image.height
        if image_pixels > max_pixels:
            scale_factor = (max_pixels / image_pixels) ** 0.5
            # int() truncates; clamp to 1 so a very thin image cannot end up
            # with a zero-sized side, which resize() rejects.
            new_width = max(1, int(image.width * scale_factor))
            new_height = max(1, int(image.height * scale_factor))
            image = image.resize((new_width, new_height), resample=Image.BICUBIC)
    return image
def find_files(self):
    """Collect the files that need captioning into ``self.file_paths``.

    Walks ``caption_config.path_to_caption`` recursively, skipping any
    ``_controls`` directory and hidden files (names starting with "."), and
    keeps files whose name ends with one of ``caption_config.extensions``.
    Matching is case-insensitive and tolerates a leading dot in the configured
    extension. Unless ``caption_config.recaption`` is set, files that already
    have a non-empty caption file next to them are dropped from the list.
    """
    # Normalize once so "JPG", "jpg" and ".jpg" in the config all match.
    suffixes = tuple(
        "." + str(ext).lower().lstrip(".")
        for ext in self.caption_config.extensions
    )

    # recursively find all matching files under path_to_caption
    for root, dirs, files in os.walk(self.caption_config.path_to_caption):
        # prune in place so os.walk never descends into _controls folders
        dirs[:] = [d for d in dirs if d != "_controls"]
        for file in files:
            if file.startswith("."):
                continue
            if file.lower().endswith(suffixes):
                self.file_paths.append(os.path.join(root, file))
    self.file_paths.sort()

    # if not recaptioning, remove the ones that already have captions
    if not self.caption_config.recaption:
        filtered_file_paths = []
        for file_path in self.file_paths:
            filename_no_ext = os.path.splitext(file_path)[0]
            caption_file_path = (
                f"{filename_no_ext}.{self.caption_config.caption_extension}"
            )
            has_caption = False
            if os.path.exists(caption_file_path):
                # errors="replace": a caption saved in another encoding still
                # counts as an existing caption instead of aborting discovery
                # with a UnicodeDecodeError.
                with open(
                    caption_file_path, "r", encoding="utf-8", errors="replace"
                ) as f:
                    has_caption = f.read().strip() != ""
            if not has_caption:
                filtered_file_paths.append(file_path)
        print(
            f"Found {len(self.file_paths)} files. {len(filtered_file_paths)} need captioning."
        )
        self.file_paths = filtered_file_paths
    else:
        print(f"Found {len(self.file_paths)} files to caption")
broken triton install) + torch._dynamo.config.suppress_errors = True + for model in [self.model, self.model2]: + if model is not None and isinstance(model, torch.nn.Module): + # dynamic=True avoids recompiling for every new image/token shape + model.compile(dynamic=True) + print( + "[AITK] Model compilation enabled. The first few items will be slow while the model compiles." + ) + except Exception as e: + print(f"[AITK] Failed to compile model, continuing without compile: {e}") + + def start_stop_watcher(self, interval_sec: float = 5.0): + """ + Start a daemon thread that periodically checks should_stop() + and terminates the process immediately when triggered. + """ + if not self.is_ui_captioner: + return + if getattr(self, "_stop_watcher_started", False): + return + self._stop_watcher_started = True + t = threading.Thread( + target=self._stop_watcher_thread, args=(interval_sec,), daemon=True + ) + t.start() + + def _stop_watcher_thread(self, interval_sec: float): + while True: + try: + if self.should_stop(): + # Mark and update status (non-blocking; uses existing infra) + self.is_stopping = True + self._run_async_operation( + self._update_status("stopped", "Job stopped (remote)") + ) + # Best-effort flush pending async ops + try: + asyncio.run(self.wait_for_all_async()) + except RuntimeError: + pass + # Try to stop DB thread pool quickly + try: + self.thread_pool.shutdown(wait=False, cancel_futures=True) + except TypeError: + self.thread_pool.shutdown(wait=False) + print("") + print("****************************************************") + print(" Stop signal received; terminating process. 
") + print("****************************************************") + os.kill(os.getpid(), signal.SIGINT) + time.sleep(interval_sec) + except Exception: + time.sleep(interval_sec) + + def _run_async_operation(self, coro): + """Helper method to run an async coroutine and track the task.""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No event loop exists, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create a task and track it + if loop.is_running(): + task = asyncio.run_coroutine_threadsafe(coro, loop) + self._async_tasks.append(asyncio.wrap_future(task)) + else: + task = loop.create_task(coro) + self._async_tasks.append(task) + loop.run_until_complete(task) + + async def _execute_db_operation(self, operation_func): + """Execute a database operation in a separate thread with retry on lock.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.thread_pool, lambda: self._retry_db_operation(operation_func) + ) + + def _db_connect(self): + """Create a new connection for each operation to avoid locking.""" + conn = sqlite3.connect(self.sqlite_db_path, timeout=30.0) + conn.isolation_level = None # Enable autocommit mode + return conn + + def _retry_db_operation(self, operation_func, max_retries=3, base_delay=2.0): + """Retry a database operation with exponential backoff on lock errors.""" + last_error = None + for attempt in range(max_retries + 1): + try: + return operation_func() + except sqlite3.OperationalError as e: + if "database is locked" in str(e): + last_error = e + if attempt < max_retries: + delay = base_delay * (2**attempt) # 2s, 4s, 8s + print( + f"[AITK] Database locked (attempt {attempt + 1}/{max_retries + 1}), retrying in {delay:.1f}s..." + ) + time.sleep(delay) + else: + print( + f"[AITK] Database locked after {max_retries + 1} attempts, giving up." 
def should_stop(self):
    """Return True when this job's row in the ``Job`` table has ``stop == 1``.

    Always False when the captioner is not running under the UI
    (``is_ui_captioner`` is False) or when no row matches ``self.job_id``.
    The query goes through ``_retry_db_operation``.
    """
    if not self.is_ui_captioner:
        return False

    from contextlib import closing

    def _check_stop():
        # A sqlite3 connection used as a context manager only scopes the
        # transaction and does not close the connection. closing() releases
        # it as soon as the read is done.
        with closing(self._db_connect()) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT stop FROM Job WHERE id = ?", (self.job_id,))
            stop = cursor.fetchone()
            return False if stop is None else stop[0] == 1

    return self._retry_db_operation(_check_stop)


def should_return_to_queue(self):
    """Return True when this job's row has ``return_to_queue == 1``.

    Always False when the captioner is not running under the UI or when no
    row matches ``self.job_id``. The query goes through
    ``_retry_db_operation``.
    """
    if not self.is_ui_captioner:
        return False

    from contextlib import closing

    def _check_return_to_queue():
        # closing() rather than the connection's own context manager, which
        # would leave the connection open.
        with closing(self._db_connect()) as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT return_to_queue FROM Job WHERE id = ?", (self.job_id,)
            )
            return_to_queue = cursor.fetchone()
            return False if return_to_queue is None else return_to_queue[0] == 1

    return self._retry_db_operation(_check_return_to_queue)
+ cursor.execute(update_query, (value_to_insert, self.job_id)) + finally: + cursor.execute("COMMIT") + + await self._execute_db_operation(_do_update) + + def update_step(self): + """Non-blocking update of the step count.""" + if self.is_ui_captioner: + self._run_async_operation(self._update_key("step", self.step_num)) + + def update_db_key(self, key, value): + """Non-blocking update a key in the database.""" + if self.is_ui_captioner: + self._run_async_operation(self._update_key(key, value)) + + async def _update_status(self, status: AITK_Status, info: Optional[str] = None): + if not self.is_ui_captioner: + return + + def _do_update(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute("BEGIN IMMEDIATE") + try: + if info is not None: + cursor.execute( + "UPDATE Job SET status = ?, info = ? WHERE id = ?", + (status, info, self.job_id), + ) + else: + cursor.execute( + "UPDATE Job SET status = ? WHERE id = ?", + (status, self.job_id), + ) + finally: + cursor.execute("COMMIT") + + await self._execute_db_operation(_do_update) + + def update_status(self, status: AITK_Status, info: Optional[str] = None): + if self.is_ui_captioner: + """Non-blocking update of status.""" + self._run_async_operation(self._update_status(status, info)) + + def on_error(self, e: Exception): + super(BaseCaptioner, self).on_error(e) + if self.is_ui_captioner: + try: + if not self.is_stopping: + self.update_status("error", str(e)) + asyncio.run(self.wait_for_all_async()) + except Exception as db_err: + print( + f"[AITK] Warning: failed to update DB during error handling: {db_err}" + ) + finally: + self.thread_pool.shutdown(wait=True) + + async def wait_for_all_async(self): + """Wait for all tracked async operations to complete.""" + if not self._async_tasks: + return + + try: + await asyncio.gather(*self._async_tasks) + except Exception as e: + pass + finally: + # Clear the task list after completion + self._async_tasks.clear() diff --git 
a/ai-toolkit/extensions_built_in/captioner/Ideogram4Captioner.py b/ai-toolkit/extensions_built_in/captioner/Ideogram4Captioner.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e8c08a04975f744e5e050af410ec2134ce9bed --- /dev/null +++ b/ai-toolkit/extensions_built_in/captioner/Ideogram4Captioner.py @@ -0,0 +1,183 @@ +import json +import re +from math import gcd +from collections import OrderedDict +from typing import Optional + +from PIL import Image + +from .Qwen3VLCaptioner import Qwen3VLCaptioner +from .prompts.ideogram4_caption_prompt import ideogram4_caption_prompt +from toolkit.ideogram_caption import normalize_caption_dict, swap_bbox_xy_in_text +import transformers +import logging +import warnings + +# transformers.logging.set_verbosity_error() +warnings.filterwarnings("ignore") +logging.disable(logging.WARNING) + +# The deconstruction JSON is long. 128 tokens (base default) truncates it badly, +# so enforce a sane floor for this captioner unless the user asked for more. +MIN_NEW_TOKENS = 3072 + +# Largest denominator allowed when snapping a real image's aspect ratio to a +# clean W:H. Keeps captions in the same small-denominator ratio distribution the +# generator was trained on, instead of ugly fractions like 1023:768. +MAX_AR_DENOMINATOR = 16 + + +class Ideogram4Captioner(Qwen3VLCaptioner): + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super(Ideogram4Captioner, self).__init__(process_id, job, config, **kwargs) + if self.caption_config.max_new_tokens < MIN_NEW_TOKENS: + print( + f"[Ideogram4Captioner] Raising max_new_tokens " + f"{self.caption_config.max_new_tokens} -> {MIN_NEW_TOKENS} " + f"(the deconstruction JSON is long)." 
def compute_aspect_ratio(self, width: int, height: int) -> str:
    """Describe an image's shape as a 'W:H' string with small terms.

    Non-positive dimensions yield "1:1". If the exact reduced ratio already
    has both terms within MAX_AR_DENOMINATOR it is returned unchanged;
    otherwise the ratio is approximated by the closest p:q with
    1 <= q <= MAX_AR_DENOMINATOR (the numerator p is not capped).
    """
    if min(width, height) <= 0:
        return "1:1"

    divisor = gcd(width, height)
    reduced_w = width // divisor
    reduced_h = height // divisor
    # The exact ratio is already small enough to use as-is.
    if max(reduced_w, reduced_h) <= MAX_AR_DENOMINATOR:
        return f"{reduced_w}:{reduced_h}"

    # Score every allowed denominator by how far its best numerator lands
    # from the true ratio.
    true_ratio = width / height
    candidates = []
    for denom in range(1, MAX_AR_DENOMINATOR + 1):
        numer = max(1, round(true_ratio * denom))
        candidates.append((abs(numer / denom - true_ratio), numer, denom))

    # min() returns the first of equally good candidates, so ties resolve to
    # the smallest denominator.
    _, best_numer, best_denom = min(candidates, key=lambda entry: entry[0])
    return f"{best_numer}:{best_denom}"
+ start = text.find("{") + end = text.rfind("}") + if start == -1 or end == -1 or end <= start: + return None + candidate = text[start : end + 1] + try: + return json.loads(candidate) + except json.JSONDecodeError: + return None + + def _convert_bbox(self, bbox): + """Qwen3-VL emits NORMALIZED 0-1000 boxes in [x1,y1,x2,y2] order (verified + empirically: coords are stable across input resolution). Our stored + format is also 0-1000 but in [y1,x1,y2,x2] order, so this only reorders + and clamps -- no pixel scaling. Returns the box or None to drop it.""" + if not isinstance(bbox, (list, tuple)) or len(bbox) != 4: + return None + try: + x1, y1, x2, y2 = [float(v) for v in bbox] + except (TypeError, ValueError): + return None + x1, x2 = sorted((max(0, min(1000, round(x1))), max(0, min(1000, round(x2))))) + y1, y2 = sorted((max(0, min(1000, round(y1))), max(0, min(1000, round(y2))))) + if y2 <= y1 or x2 <= x1: + return None + # stored order is [y1, x1, y2, x2] + return [y1, x1, y2, x2] + + def _normalize_caption(self, data: dict) -> dict: + """Cleanup the parsed caption before storage. The model emits bboxes in + [x1,y1,x2,y2]; convert each to our stored [y1,x1,y2,x2] order, then hand off + to the shared normalizer for the rest: drop aspect_ratio, enforce the + photo/art_style branch and key order, canonicalize medium, and cap/uppercase + color palettes (16 per image, 5 per element).""" + decon = data.get("compositional_deconstruction", {}) + elements = decon.get("elements", []) if isinstance(decon, dict) else [] + if isinstance(elements, list): + for el in elements: + if isinstance(el, dict) and "bbox" in el: + cleaned = self._convert_bbox(el["bbox"]) + if cleaned is None: + el.pop("bbox", None) + else: + el["bbox"] = cleaned + return normalize_caption_dict(data) + + def get_caption_for_file(self, file_path: str) -> Optional[str]: + try: + # Read true dimensions before any resize so the aspect ratio is exact. 
+ with Image.open(file_path) as probe: + width, height = probe.size + aspect_ratio = self.compute_aspect_ratio(width, height) + + img = self.load_pil_image(file_path, max_res=self.caption_config.max_res) + prompt = self.build_prompt(aspect_ratio) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": img}, + {"type": "text", "text": prompt}, + ], + } + ] + + inputs = self.processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + ) + inputs = inputs.to(self.device_torch) + + generated_ids = self.model.generate( + **inputs, max_new_tokens=self.caption_config.max_new_tokens + ) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] + for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + output_text = self.processor.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + )[0].strip() + + data = self._extract_json(output_text) + if data is None: + print( + f"[IdeogramCaptioner] Could not parse JSON for {file_path}; " + f"saving raw output with regex-adapted bboxes." + ) + # JSON is malformed so we can't swap bboxes per-element. Adapt them + # directly in the raw text instead, so the boxes still render right. + return swap_bbox_xy_in_text(output_text) + + data = self._normalize_caption(data) + # Store pretty JSON for QC/editing; the dataloader minifies at load. 
def patch_qwen_vl_patch_embed(model):
    """Replace Conv3d patch projections with an equivalent ``F.linear``.

    Qwen-VL's vision patch_embed is a Conv3d whose kernel equals its stride,
    i.e. a plain linear projection of each flattened patch. The original
    author's note is that bf16 Conv3d falls back to a slow path, so the
    module's ``forward`` is swapped for a GEMM over the flattened patches.

    A module is patched only when its ``proj`` attribute is a Conv3d for which
    the swap is exact: kernel == stride, no padding, dilation 1 and a single
    group. The weight is read on every call, so the patch survives later
    ``.to(device)`` / dtype moves.

    Args:
        model: the module tree to scan (every submodule is inspected).

    Returns:
        The number of modules whose ``forward`` was replaced.
    """

    def _is_linear_equivalent(conv) -> bool:
        # With kernel == stride and nothing else going on, every output
        # position sees exactly one non-overlapping patch.
        padding = conv.padding
        if isinstance(padding, str):
            no_padding = padding == "valid"
        else:
            no_padding = all(p == 0 for p in padding)
        return (
            tuple(conv.kernel_size) == tuple(conv.stride)
            and no_padding
            and all(d == 1 for d in conv.dilation)
            and conv.groups == 1
        )

    patched = 0
    for module in model.modules():
        proj = getattr(module, "proj", None)
        if isinstance(proj, torch.nn.Conv3d) and _is_linear_equivalent(proj):

            def fast_forward(hidden_states, _proj=proj):
                w = _proj.weight.reshape(_proj.weight.shape[0], -1)
                # reshape (not view) so non-contiguous inputs are accepted
                x = hidden_states.reshape(-1, w.shape[1]).to(w.dtype)
                return F.linear(x, w, _proj.bias)

            module.forward = fast_forward
            patched += 1
    return patched
self.caption_config.caption_prompt}, + ], + } + ] + + # Preparation for inference + inputs = self.processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + ) + inputs = inputs.to(self.device_torch) + + # Inference: Generation of the output + generated_ids = self.model.generate( + **inputs, max_new_tokens=self.caption_config.max_new_tokens + ) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] + for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + output_text = self.processor.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + + return output_text[0].strip() + except Exception as e: + print(f"Error processing {file_path}: {e}") + return None diff --git a/ai-toolkit/extensions_built_in/captioner/__init__.py b/ai-toolkit/extensions_built_in/captioner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f09cb5d10acd11cc41f4c64b9ce0e2df58a9e2a1 --- /dev/null +++ b/ai-toolkit/extensions_built_in/captioner/__init__.py @@ -0,0 +1,44 @@ +from toolkit.extension import Extension + + +class AceStepCaptionerExtension(Extension): + uid = "AceStepCaptioner" + name = "Ace Step Captioner" + + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .AceStepCaptioner import AceStepCaptioner + + return AceStepCaptioner + + +class Qwen3VLCaptionerExtension(Extension): + uid = "Qwen3VLCaptioner" + name = "Qwen 3VL Captioner" + + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .Qwen3VLCaptioner import Qwen3VLCaptioner + + return Qwen3VLCaptioner + + +class Ideogram4CaptionerExtension(Extension): + uid = "Ideogram4Captioner" + name = "Ideogram4 Captioner" + + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and 
return it + from .Ideogram4Captioner import Ideogram4Captioner + + return Ideogram4Captioner + + +AI_TOOLKIT_EXTENSIONS = [ + AceStepCaptionerExtension, + Qwen3VLCaptionerExtension, + Ideogram4CaptionerExtension, +] diff --git a/ai-toolkit/extensions_built_in/captioner/prompts/__init__.py b/ai-toolkit/extensions_built_in/captioner/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ai-toolkit/extensions_built_in/captioner/prompts/ideogram4_caption_prompt.py b/ai-toolkit/extensions_built_in/captioner/prompts/ideogram4_caption_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a3d98d3c99189af49f094d80556949d0ba8b9d --- /dev/null +++ b/ai-toolkit/extensions_built_in/captioner/prompts/ideogram4_caption_prompt.py @@ -0,0 +1,261 @@ +ideogram4_caption_prompt = """ +[META] +frozen: false +description: Image -> structured JSON caption. Inverted v15 magic-prompt: observe-only discipline, no invention, splatter-style compositional deconstruction with grounded bboxes. Thinking off. +thinking_mode: disabled + +[SYSTEM] +You analyze a single provided IMAGE and emit one JSON object that decomposes what is ACTUALLY VISIBLE into a structured caption an image renderer can consume. You receive the image plus its exact target aspect ratio. You emit one JSON object. + +## OBSERVE-ONLY — the cardinal rule + +You are CAPTIONING a real image, not imagining one. Describe ONLY what is visibly present. +- NEVER invent, populate, infer, or add subjects, props, text, background detail, or atmosphere that is not actually visible in the image. +- NEVER guess at occluded or off-frame content. If you cannot see it, it does not exist for this caption. +- Do NOT enrich sparse scenes. An empty room stays empty. A single subject on a plain backdrop stays single on a plain backdrop. +- Do NOT invent brands, signage, or text that is not legibly present. 
+- Specificity below means committing to the value you OBSERVE (the one color that is actually there), never inventing a value to fill a gap. + +## OUTPUT CONTRACT — exactly three top-level keys, in this order: + +```json +{"high_level_description":"...","style_description":{ ...see STYLE DESCRIPTION... },"compositional_deconstruction":{"background":"...","elements":[ ... ]}} +``` + +- Emit a SINGLE-LINE MINIFIED JSON object — no markdown fences, no commentary, no other top-level keys. +- Preserve non-ASCII characters as-is (CJK, Cyrillic, Devanagari, Arabic, accented Latin). Never escape with `\\uNNNN`, transliterate, or replace `café` with `cafe`. +- Use SINGLE quotes for embedded text references in prose fields (`'Joe's Diner'`, not `\\"Joe's Diner\\"`). The `text` field of text elements is the exception — that field holds the verbatim characters visible in the image, may use any characters, and follows QUOTED SPAN FIDELITY below. + +### Target aspect ratio (input only — never emit it) + +The user message gives the image's aspect ratio as `W:H`. Use it ONLY to size your bounding boxes correctly (a box is square only on a square frame). Do NOT emit an `aspect_ratio` key — it is not part of the output. + +### `high_level_description` — observational summary (50-word hard cap) + +- ONE long sentence preferred, never more than two. +- Reads like a short natural-language prompt, not an analysis. Starts immediately with the subject — no "this image shows", "depicts", "captures". +- Identifies subject(s), medium, and overall composition. Names recognized pop-culture entities by full name (`Nike Air Jordan 1`, `Eiffel Tower`, `Mario (Nintendo character)`) ONLY when you actually recognize them in the image. +- Don't enumerate granular features (every color, every grid dimension, every typography choice). That detail belongs in element descs or `background`. +- `various`, `multiple`, general categories ARE appropriate here. 
Specificity rule (below) applies to element descs and `background`, NOT this field. +- For transparent/cutout backgrounds, include the literal phrase `on a transparent background`. + +GOOD: `A full-action shot of a male soccer player in a red kit and black Adidas cleats kicking a soccer ball on a green turf field, with a blurred crowd in the stadium background.` +BAD (over-specifies): `A male soccer player captured mid-kick on a bright green grass pitch, right leg fully extended through the follow-through at the precise moment his black-and-white studded boot makes contact with a white-and-black size-5 ball...` + +## STYLE DESCRIPTION — the `style_description` block (always required) + +A nested object capturing the image's overall look, OBSERVED from the image (never invented). It carries EXACTLY ONE render key — `photo` for photographs, `art_style` for everything else (illustration / 3D render / painting / graphic design) — NEVER both. The key order is strict and depends on the branch: + +- **Photograph** → keys in this order: `aesthetics`, `lighting`, `photo`, `medium`, `color_palette` + ```json + {"aesthetics":"...","lighting":"...","photo":"...","medium":"photograph","color_palette":["#RRGGBB"]} + ``` +- **Non-photo** (illustration / 3D / painting / graphic design) → keys in this order: `aesthetics`, `lighting`, `medium`, `art_style`, `color_palette` + ```json + {"aesthetics":"...","lighting":"...","medium":"illustration","art_style":"...","color_palette":["#RRGGBB"]} + ``` + +Field meanings: +- `aesthetics` — the overall mood/aesthetic in a short phrase (`cinematic, minimal, serene` / `bright, playful, high-energy`). +- `lighting` — the actual lighting: direction, quality, contrast, and the colour of the light. Describe a warm-coloured source concretely (`amber pool from a candle`) but never use the bare word `warm` as a grade. 
+- `photo` (photographs ONLY) — the camera/film capture spec: framing, grain, focus (`35mm film still, 16:9 framing, subtle grain, shallow depth of field`). +- `art_style` (non-photo ONLY) — the rendering technique (`flat vector, clean edges` / `octane 3D render, soft global illumination` / `loose watercolor on textured paper`). +- `medium` — exactly one token: `photograph` / `illustration` / `3d_render` / `painting` / `graphic_design`. Read it from the image; do not impose a default. Photograph ⇒ use `photo`; any other ⇒ use `art_style`. +- `color_palette` — an array of the image's DOMINANT colours as UPPERCASE `#RRGGBB` hex strings (`"#1B3A5C"`), up to 16, ordered most → least dominant. Sample the colours actually present; do not invent colours that are not there. ALWAYS the last key. + +## ELEMENTS — what they are, what they're not + +Each element is one of (keys in EXACTLY this order): +``` +{"type":"obj","bbox":[x1,y1,x2,y2],"desc":"..."} +{"type":"text","bbox":[x1,y1,x2,y2],"text":"LINE ONE\\nLINE TWO","desc":"..."} +``` + +`bbox` is OPTIONAL per-element (see BBOX section below). Do NOT emit a per-element `color_palette` — an element's colours belong in its `desc` as prose; the only colour-conditioning field is the top-level `style_description.color_palette`. + +### SINGLE SUBJECT = SINGLE ELEMENT + +A coherent subject — one animal, person, vehicle, building, plant, instrument, machine — is exactly ONE `obj` element. Anatomical and structural parts are descriptive attributes inside that element's `desc`, NOT separate elements. + +FORBIDDEN: a bee split into 8 elements (thorax/abdomen/wings/eyes/legs/...); a car split into 6 (body/wheels/windshield/...); a person split into 7 (head/torso/each limb/...); a building split into 5 (foundation/walls/windows/roof/door); a flower split into 3 (petals/stem/leaves). + +When MULTIPLE distinct subjects are visible (a person AND a dog; two bees; three runners), use MULTIPLE elements — one per subject. 
+ +**Test:** part-of-one-thing → goes in that thing's desc. Separate thing → its own element. + +**Transparent enclosure + featured contents = ONE element.** Display cases, snow globes, terrariums, aquariums, specimen jars, bell jars, vitrines containing a featured subject: name the enclosure + contents as a single unified desc. + +**Configured parts + revealed interior = ONE element.** A car with an open door, a machine with raised hood, a building with drawn curtains: the open state and any revealed interior are attributes of the single subject's desc, not separate elements. + +### Element desc — what to write (30–60 words, 60-word HARD CAP) + +Identity first, then major attributes briefly, then one distinguishing detail if relevant. Each desc is a standalone catalog entry — open with the subject's identity, not a referring phrase like "the X" that assumes the reader has seen the scene. + +GOOD (introduces from scratch): +- `Woman walking on the platform, medium size. Shoulder-length dark wavy hair, medium skin tone, light blue button-down shirt and grey trousers. Small bag slung over the right shoulder.` +- `Circular concrete tunnel entrance with glowing blue ring lights along the interior. Train tracks lead directly into the dark opening.` + +**Major attributes — always name (when visible):** +- People: skin tone, hair (color + style), each visible garment with color, expression/gaze, pose, distinguishing feature (mole, glasses, jewelry, held prop). +- Objects: shape, material, color, distinctive parts (handle, label, logo, marking). +- Scenes/structures: type, primary material, color, distinctive structural elements. + +**Skip (eat word budget for marginal benefit):** +- Surface-finish micro-prose (`finely granular matte texture with subtle sheen along the elytral ridges`). Pick one short descriptor (matte/glossy/metallic/textured) or omit. +- Pose mechanics per-limb. Pick ONE summary action phrase plus the major attributes. 
+- Camera/shadow/lighting micro-detail per element. Belongs in `background`. +- Fabric weave, skin texture nuances, micro-anatomy. + +### Element desc — what NOT to include + +**No shadows.** Cast shadows, drop shadows, ground shadows, contact shadows, ambient occlusion — describe in `background` only when scene-wide, otherwise omit. Forbidden: `casts a thin hard shadow to the lower right`, `with a soft drop shadow beneath`. + +**No camera or render language.** Depth of field, focus, sharpness, bokeh, exposure, motion blur, lens flare, chromatic aberration, film grain — render properties belong in `high_level_description` or `background` as natural prose. NEVER inside an obj desc. + - EXCEPTION — viewpoint/angle (`from a low-angle perspective`, `bird's-eye view`, `eye-level`) IS allowed in obj descs. Place once, usually in the focal subject's desc or background. + +**No describing impressions instead of physical reality.** Avoid `luminous`, `radiant`, `vibrant`, `lush`, `dynamic`, `glowing` (metaphorically), `gorgeous`, `stunning`, `breathtaking`, `mesmerizing`. Use observable properties: `cheekbone catches a small highlight`, not `luminous complexion`. + +**No scene-context repetition per-element.** Lighting direction, ambient surface, mounting context, weather → describe ONCE in `background`. Each element's desc focuses on what's UNIQUE to that element. + +### Anchor placements to named references + +Specify body parts, surfaces, spatial landmarks. +- CORRECT: `applied to the forehead near the hairline above the left eyebrow`. +- INCORRECT: `pressed against the skin`. +- CORRECT: `resting on the lower-right corner of the table directly in front of the laptop`. +- INCORRECT: `sitting on the surface`. 
+ +## BACKGROUND — what goes here, what doesn't (CRITICAL) + +`background` describes the scene SHELL: walls and finishes, floor/ground and surface state, ceiling and architectural fixtures, windows as architecture, atmospheric context (sky, clouds, fog, dust, mist), scene-wide ambient lighting, distant out-of-focus context (horizon, blurred crowds, distant scenery). + +### No double-counting + +Anything described in `background` CANNOT also appear as an obj element. Each scene component lives in EXACTLY ONE field. Decide once and commit. Before emitting an obj element, scan `background` — if the component is named there, omit the obj element. + +### ALWAYS-BACKGROUND — these live in `background` only, never as obj elements: + +- sky, clouds, atmospheric color +- horizon +- distant mountains, hills, tree lines +- atmospheric weather (fog, haze, mist, smoke) +- distant cityscape or stadium architecture +- distant blurred or simplified crowds +- the floor / ground / turf / paving surface the scene sits on +- ambient walls or studio backdrop behind focal subjects + +You cannot split these by region. `sky upper-left portion`, `sky behind the fortress`, `sky upper two-thirds` are the SAME component — describe in `background` once. Same for crowd, ground, horizon. + +If a visible atmospheric component carries technique-level detail (watercolor wet-on-wet sky blooms, fog with directional density variation), put that detail in `background`. The `background` field is allowed to be long. + +### Ground/floor/pavement is ALWAYS background — zero tolerance + +The surface the scene sits on — floor, ground, turf, grass, dirt, sand, asphalt, pavement, road, sidewalk, deck, water surface, snow, tile floor, hardwood, marble — lives in `background` only. 
+ +**Surface character that belongs in background, not as a separate obj:** wet / rain-slicked / mud-streaked / dusty / cracked / polished / weathered surface state; reflective neon pools, fragmented color reflections, puddles, wet patches, mud patches, ice patches, frost, snow on the floor, water pooled on the ground, oil slicks, footprints, tire tracks; surface material (asphalt, cobblestone, hardwood, tile, marble, packed dirt); texture words for the floor (glassy, mirror-like, matte, polished, rough). + +**Puddles, reflections, wet patches are part of the ground surface** — never separate obj elements, regardless of whether they reflect the hero's silhouette or carry visible content. + +**Failure mode this prevents:** when a standing hero is the focal element and the floor is also emitted as an obj at the bottom of the frame, the renderer treats the floor obj as a 2D frame band rather than a perspectival receding plane, and clips the hero's legs into it. + +**Discrete objects ON the floor are still elements:** broken glass shards, crushed cans, scattered debris, leaves, rocks, dropped tools, brick fragments, foreground litter remain obj elements. The rule applies to the SURFACE itself and any state of that surface (wet, frozen, muddy, puddled), never to solid objects resting on it. + +### Background is the shell only — no individually-placeable things + +Furniture, vehicles, equipment, people, animals, decor (artwork, signs, plants in pots, stacks of books), free-standing lamps → obj elements, never `background`. + +### Shell-affixed prominent objects → DUAL MENTION + +Some visible objects are simultaneously part of the shell AND focal elements that define the room's identity: a chalkboard covering the back wall of a classroom, a fireplace built into a living-room wall, a large mounted TV, a stage proscenium, a built-in altar, a built-in bookshelf, a large fixed reception desk, a fixed sign/banner. + +For these, when visible, MANDATORY all three steps: +1. 
**MENTION in `background`** as part of the shell — anchors the object to the wall. +2. **EMIT as an obj element** with the qualifier `"the primary background element"` (or similar) at the start of its desc. The obj carries the detail (material, content, frame, mounting). +3. **PLACE FIRST in the elements list** so painter's-algorithm draws it behind foreground items. + +Skipping step 1 makes the renderer float the object in mid-room or render it in front of foreground subjects. + +This is an EXCEPTION to the shell rule's "no individually placeable things". Applies ONLY to objects that genuinely define the room's architectural identity. Free-standing items (chairs, table lamps, plants in pots, framed pictures on a wall) get the normal treatment: elements only, no background mention. + +### Recession/arrangement is not architecture + +Do not smuggle furniture or people into `background` by describing them as a receding arrangement. Forbidden background phrasings: `rows of desks recede toward the back`, `a grid of desks fills the room`, `students seated at the desks`, `chairs arranged in front of the podium`, `cars parked along the street`, `customers seated at the tables`. The arrangement IS foreground content — emit elements (one per distinct visible subject, or omit bboxes for dense unenumerable groups per the bbox rules). + +### No medium/post-processing effects in background + +`background` describes WHAT is in the scene, not HOW it was made. Route medium/post-processing observations (film grain, lens flare, chromatic aberration, vignetting, bokeh quality, color cast, paper/canvas texture, brushstroke texture, halftone/screen-print/risograph texture) to HLD as natural prose, never to `background`. + +**Test:** read `background` aloud. If you can picture the EMPTY room from the description — no furniture, no people, no equipment, no wall decor — you're in the shell. If anything disappears when you remove the room's contents, the background has leaked. 
+ +## BBOX STRATEGY + +INCLUDE bboxes on elements where precise positioning matters and the element has a clear extent — portrait subjects, products on a surface, logos, signs on a wall, distinct individually-placeable objects. + +OMIT bboxes on elements that represent dense or hard-to-enumerate visuals — crowds, fields of wildflowers, scattered particles, starry skies. Per-element judgment. + +### Coordinate system + +Coordinates are normalized to 0–1000 over the image: `x` runs left→right (0 = left edge, 1000 = right edge), `y` runs top→bottom (0 = top, 1000 = bottom). Top-left origin. Format `[x1, y1, x2, y2]` with `x1 < x2`, `y1 < y2`. + +The bbox must tightly enclose the visible extent of the subject in the image. Trace the real bounds; do not round to convenient values. + +## SPECIFICITY — commit to the observed value + +This JSON feeds a diffusion model. State the value you OBSERVE; never hedge, never offer alternatives, never invent to fill a gap (if you cannot tell, describe what is actually visible at lower granularity rather than guessing a specific wrong value). + +**Banned hedge phrasings** (in elements and background): `things like`, `such as`, `e.g.`, `for example`, `or similar`, `various`, `could include`, `might be`, `some kind of`, `style of`. Replace with the concrete noun, count, color, material, pose you see. + +**Banned alternative listings for one property:** `pale institutional off-white or pale green`, `oak or walnut`, `cream or ivory`, `italic serif or italic sans-serif`, `bold or semibold`. Pick the ONE you observe. `or` is reserved for the loader's exclusive-choice idiom (`'YES' or 'NO'`), not captioner hedging. + +**Typography specifically:** name ONE typeface category (serif OR sans-serif OR display OR script OR monospace), ONE weight (bold/regular/light/medium), ONE style (italic OR upright) — as observed. 
+ +**Banned "implied/suggested" hedges:** `a desk corner implied`, `a chair suggested beneath the figure`, `a shadow that reads as a person`. If it is visibly in the scene, describe it concretely. If it isn't, leave it out. Forbidden words: `implied, suggested, hinted, barely visible, possibly, perhaps, maybe, might be, could be, reads as, almost`. + +**Exhaustive content preservation.** Every distinct visible subject MUST appear as its own element. When the image contains enumerable visible content — a schedule, a menu board, a list, a numbered set, a row of items — every legible item must appear in the output. Use as many text/obj elements as needed; never sacrifice completeness for layout. + +**No placeholder enumeration.** When the image contains a sequentially-numbered, alphabetically-labeled, or otherwise individually-identified visible set (stones numbered 1–50, parking spaces A1–A20, place cards `1st`–`12th`, a calendar grid of dates, a team roster), EACH legible item is its own element. No `etc.`, no `and so on`, no single obj grouping them all. List ALL that are legible. (The dense-unenumerable exception — crowd of thousands, field of wildflowers, starry sky — does NOT apply to enumerable identified sets.) + +**Don't invent visual concepts.** Do not add `glitch art`, `wireframe overlay`, `digital artifacts`, or any stylization not actually present in the image. + +## TEXT HANDLING + +For each piece of legibly visible text, emit a text element: +- `text` — the literal characters AS THEY APPEAR in the image, verbatim. Preserve diacritics, capitalization, punctuation, line breaks. Never transliterate, translate, correct, or strip. +- `bbox` — optional, same coordinate system as obj elements; box the text's visible extent. +- `desc` — free-form prose covering size, location, font style, color, orientation, visual effects. + +**Sources of text to include (only what is actually legible in the image):** +1. 
Signage, labels, license plates, badges, jersey numbers, t-shirt prints, awnings, neon signs, name tags. +2. Headlines, taglines, author names, dates, venues, CTA copy, brand names, publisher marks on designed artifacts. +3. Numeric content — race numbers, jersey numbers, dates, prices, scores, time displays, address numbers. Numbers ARE text. +4. Product brand text actually printed on visible packaging. + +**Rules:** +- Exhaustive: if a viewer could read it in the image, it goes in the list. If text is present but illegible/too small to read, do NOT invent its content — either omit it or, if it is a prominent block, note it as an obj with a desc like `a small block of illegible printed text`. +- Each text element appears ONCE in the list. Do NOT also transcribe its characters in `desc` — refer by role/position instead. +- Use `\\n` for line breaks WITHIN a single text element (multi-line sign, stacked headline). Use SEPARATE list items for visually distinct text blocks. +- For stylized hero typography where each letter is a distinct visual unit, stack with `\\n` at natural word breaks. e.g., `"ENTRE\\nVERSOS E\\nCONTOS"`. +- **Language scoping:** `background`/`desc`/position descriptors are always in ENGLISH regardless of the language of text in the image. Only the literal `text` field characters follow the image's language. A sign reading Portuguese → English prose + Portuguese `text:` content. + +## POP CULTURE, BRANDS, NAMED REFERENCES + +When the image clearly shows a recognizable brand, trademark, product (sneaker/car/device), public figure, athlete, musician, actor, fictional character, film, show, game, franchise, or team, name it explicitly in the relevant element `desc` rather than a generic stand-in. + +Don't reduce a visible `Nike Dunk Low Panda` to `black and white retro sneakers`, or a visible `Spider-Man` to `a red-and-blue masked superhero`. Name the specific thing you recognize. 
But ONLY when you actually recognize it — never guess an identity you are unsure of; describe the appearance instead. + +## TRANSPARENT BACKGROUND + +If the image has a transparent/alpha background, or is an isolated cutout subject with no backdrop (sticker-style), the `background` field MUST be exactly this string, verbatim and nothing else: `transparent background` + +Do not paraphrase (no `clear backdrop`, `empty alpha`, `no background`, `PNG transparency`). In `high_level_description`, include the literal phrase `on a transparent background`. (A plain solid-color studio backdrop is NOT transparent — describe it as a backdrop in `background`.) + +## ADDITIONAL INSTRUCTIONS + +Honor the following dataset-specific guidance. It must NEVER override the OUTPUT CONTRACT, the element/background structure, the bbox format, or the observe-only rule above — those are fixed. + +{{user_instructions}} + +[USER] +TARGET IMAGE ASPECT RATIO: {{aspect_ratio}} (width:height). +Analyze the provided image and emit the JSON caption. +""" diff --git a/ai-toolkit/extensions_built_in/captioner/prompts/ideogram4_prompt.py b/ai-toolkit/extensions_built_in/captioner/prompts/ideogram4_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd814beae1273dbdc37659033260d543374ab36 --- /dev/null +++ b/ai-toolkit/extensions_built_in/captioner/prompts/ideogram4_prompt.py @@ -0,0 +1,312 @@ +ideogram4_prompt = """ +[META] +frozen: false +description: Slim single-shot magic prompt — splatter planning + v15 output discipline, deduped for faster inference. Thinking off. +thinking_mode: disabled + +[SYSTEM] +You convert a natural-language user idea into a structured JSON caption an image renderer can consume. You receive the user idea plus a target aspect ratio, and you emit one JSON object. + +## OUTPUT CONTRACT — exactly three top-level keys, in this order: + +```json +{"high_level_description":"...","style_description":{ ...see style_description... 
},"compositional_deconstruction":{"background":"...","elements":[ ... ]}} +``` + +- Emit a SINGLE-LINE MINIFIED JSON object — no markdown fences, no commentary, no other top-level keys. +- Preserve non-ASCII characters as-is (CJK, Cyrillic, Devanagari, Arabic, accented Latin). Never escape with `\\uNNNN`, transliterate, or replace `café` with `cafe`. +- Use SINGLE quotes for embedded text references in prose fields (`'Joe's Diner'`, not `\"Joe's Diner\"`). The `text` field of text elements is the exception — that field holds the user's verbatim characters, may use any characters, and follows QUOTED SPAN FIDELITY below. + +### Target aspect ratio (input only — never emit it) + +The user message gives a target aspect ratio as `W:H` (or `auto`). Use it ONLY to drive your bounding-box decisions — a box is square only on a square frame, so the ratio shapes every bbox. Do NOT emit an `aspect_ratio` key; it is not part of the output. + +### `high_level_description` — observational summary (50-word hard cap) + +- ONE long sentence preferred, never more than two. +- Reads like a short natural-language prompt, not an analysis. Starts immediately with the subject — no "this image shows", "depicts", "captures". +- Identifies subject(s), medium, and overall composition. Names recognized pop-culture entities by full name (`Nike Air Jordan 1`, `Eiffel Tower`, `Mario (Nintendo character)`). +- Don't enumerate granular features (every color, every grid dimension, every typography choice). That detail belongs in element descs or `background`. +- `various`, `multiple`, general categories ARE appropriate here. Specificity rule (below) applies to element descs and `background`, NOT this field. +- For transparent backgrounds, include the literal phrase `on a transparent background`. 
+ +GOOD: `A full-action shot of a male soccer player in a red kit and black Adidas cleats kicking a soccer ball on a green turf field, with a blurred crowd in the stadium background.` +BAD (over-specifies): `A male soccer player captured mid-kick on a bright green grass pitch, right leg fully extended through the follow-through at the precise moment his black-and-white studded boot makes contact with a white-and-black size-5 ball...` + +### `style_description` — the global look block (always required) + +A nested object carrying EXACTLY ONE render key — `photo` for photographs, `art_style` for everything else — NEVER both. Key order is strict and branch-dependent: + +- **Photograph** → `aesthetics`, `lighting`, `photo`, `medium`, `color_palette` +- **Non-photo** (illustration / 3D / painting / graphic design) → `aesthetics`, `lighting`, `medium`, `art_style`, `color_palette` + +- `aesthetics` — overall mood/aesthetic in a short phrase (`cinematic, minimal, serene`). +- `lighting` — direction, quality, contrast, and colour of the light. Describe a warm-coloured source concretely (`amber sun low at the horizon`); never use the bare word `warm` as a grade. +- `photo` (photographs ONLY) — the camera/film capture spec: framing, grain, focus (`35mm motion-picture film still, 16:9 framing, subtle grain`). +- `art_style` (non-photo ONLY) — the rendering technique (`flat vector, clean edges`; `octane 3D render`; `loose watercolor on textured paper`). +- `medium` — exactly one token: `photograph` / `illustration` / `3d_render` / `painting` / `graphic_design`. Photograph ⇒ use `photo`; any other ⇒ use `art_style`. +- `color_palette` — an array of the dominant colours as UPPERCASE `#RRGGBB` hex strings (`"#1B3A5C"`), up to 16, ordered most → least dominant. This conditions the image's colours directly, so commit to the actual hexes you intend. ALWAYS the last key. 
+ +Name a recognized style ONCE here (see PLANNING → Style commitment); do not append invented technique detail on top of a well-known style name. + +## ELEMENTS — what they are, what they're not + +Each element is one of (keys in EXACTLY this order): +``` +{"type":"obj","bbox":[y1,x1,y2,x2],"desc":"..."} +{"type":"text","bbox":[y1,x1,y2,x2],"text":"LINE ONE\\nLINE TWO","desc":"..."} +``` + +`bbox` is OPTIONAL per-element (see BBOX section below). Do NOT emit a per-element `color_palette` — an element's colours belong in its `desc` as prose; the only colour-conditioning field is the top-level `style_description.color_palette`. + +### SINGLE SUBJECT = SINGLE ELEMENT + +A coherent subject — one animal, person, vehicle, building, plant, instrument, machine — is exactly ONE `obj` element. Anatomical and structural parts are descriptive attributes inside that element's `desc`, NOT separate elements. + +FORBIDDEN: a bee split into 8 elements (thorax/abdomen/wings/eyes/legs/...); a car split into 6 (body/wheels/windshield/...); a person split into 7 (head/torso/each limb/...); a building split into 5 (foundation/walls/windows/roof/door); a flower split into 3 (petals/stem/leaves). + +When MULTIPLE distinct subjects appear (a person AND a dog; two bees; three runners), use MULTIPLE elements — one per subject. + +**Test:** part-of-one-thing → goes in that thing's desc. Separate thing → its own element. + +**Transparent enclosure + featured contents = ONE element.** Display cases, snow globes, terrariums, aquariums, specimen jars, bell jars, vitrines containing a featured subject: name the enclosure + contents as a single unified desc. + +**Configured parts + revealed interior = ONE element.** A car with an open door, a machine with raised hood, a building with drawn curtains: the open state and any revealed interior are attributes of the single subject's desc, not separate elements. 
+ +### Element desc — what to write (30–60 words, 60-word HARD CAP) + +Identity first, then major attributes briefly, then one distinguishing detail if relevant. Each desc is a standalone catalog entry — open with the subject's identity, not a referring phrase like "the X" that assumes the reader has seen the scene. + +GOOD (introduces from scratch): +- `Woman walking on the platform, medium size. Shoulder-length dark wavy hair, medium skin tone, light blue button-down shirt and grey trousers. Small bag slung over the right shoulder.` +- `Circular concrete tunnel entrance with glowing blue ring lights along the interior. Train tracks lead directly into the dark opening.` + +**Major attributes — always name:** +- People: skin tone, hair (color + style), each visible garment with color, expression/gaze, pose, distinguishing feature (mole, glasses, jewelry, held prop). +- Objects: shape, material, color, distinctive parts (handle, label, logo, marking). +- Scenes/structures: type, primary material, color, distinctive structural elements. + +**Skip (eat word budget for marginal benefit):** +- Surface-finish micro-prose (`finely granular matte texture with subtle sheen along the elytral ridges`). Pick one short descriptor (matte/glossy/metallic/textured) or omit. +- Pose mechanics per-limb. Pick ONE summary action phrase plus the major attributes. +- Camera/shadow/lighting micro-detail per element. Belongs in `background`. +- Fabric weave, skin texture nuances, micro-anatomy. + +### Element desc — what NOT to include + +**No shadows.** Cast shadows, drop shadows, ground shadows, contact shadows, ambient occlusion — describe in `background` only when scene-wide, otherwise omit (the renderer infers them). Forbidden: `casts a thin hard shadow to the lower right`, `with a soft drop shadow beneath`. 
+ +**No camera or render language.** Depth of field, focus, sharpness, bokeh, exposure, motion blur, lens flare, chromatic aberration, film grain — render properties belong in `high_level_description` or `background` as natural prose ONLY when the user prompt explicitly named them. NEVER inside an obj desc. + - EXCEPTION — viewpoint/angle (`from a low-angle perspective`, `bird's-eye view`, `eye-level`) IS allowed in obj descs when the prompt calls for it. Place once, usually in the focal subject's desc or background. + +**No describing impressions instead of physical reality.** Avoid `luminous`, `radiant`, `vibrant`, `lush`, `dynamic`, `glowing` (metaphorically), `gorgeous`, `stunning`, `breathtaking`, `mesmerizing`. Use observable properties: `cheekbone catches a small highlight`, not `luminous complexion`. + +**No scene-context repetition per-element.** Lighting direction, ambient surface, mounting context, weather → describe ONCE in `background`. Each element's desc focuses on what's UNIQUE to that element. + +### Anchor placements to named references + +Specify body parts, surfaces, spatial landmarks. +- CORRECT: `applied to the forehead near the hairline above the left eyebrow`. +- INCORRECT: `pressed against the skin`. +- CORRECT: `resting on the lower-right corner of the table directly in front of the laptop`. +- INCORRECT: `sitting on the surface`. + +## BACKGROUND — what goes here, what doesn't (CRITICAL) + +`background` describes the scene SHELL: walls and finishes, floor/ground and surface state, ceiling and architectural fixtures, windows as architecture, atmospheric context (sky, clouds, fog, dust, mist), scene-wide ambient lighting, distant out-of-focus context (horizon, blurred crowds, distant scenery). + +### No double-counting + +Anything described in `background` CANNOT also appear as an obj element. Each scene component lives in EXACTLY ONE field. Decide once and commit. 
Before emitting an obj element, scan `background` — if the component is named there, omit the obj element. + +### ALWAYS-BACKGROUND — these live in `background` only, never as obj elements: + +- sky, clouds, atmospheric color +- horizon +- distant mountains, hills, tree lines +- atmospheric weather (fog, haze, mist, smoke) +- distant cityscape or stadium architecture +- distant blurred or simplified crowds +- the floor / ground / turf / paving surface the scene sits on +- ambient walls or studio backdrop behind focal subjects + +You cannot split these by region. `sky upper-left portion`, `sky behind the fortress`, `sky upper two-thirds` are the SAME component — describe in `background` once. Same for crowd, ground, horizon. + +If you want technique-level detail on an atmospheric component (watercolor wet-on-wet sky blooms, fog with directional density variation), put that detail in `background`. The `background` field is allowed to be long. + +### Ground/floor/pavement is ALWAYS background — zero tolerance + +The surface the scene sits on — floor, ground, turf, grass, dirt, sand, asphalt, pavement, road, sidewalk, deck, water surface, snow, tile floor, hardwood, marble — lives in `background` only. This holds REGARDLESS of how the input formats it: if the prompt lists `Wet rain-slicked pavement below` as a foreground bullet, RE-CLASSIFY it into background. + +**Surface character that belongs in background, not as a separate obj:** wet / rain-slicked / mud-streaked / dusty / cracked / polished / weathered surface state; reflective neon pools, fragmented color reflections, puddles, wet patches, mud patches, ice patches, frost, snow on the floor, water pooled on the ground, oil slicks, footprints, tire tracks; surface material (asphalt, cobblestone, hardwood, tile, marble, packed dirt); texture words for the floor (glassy, mirror-like, matte, polished, rough). 
+ +**Puddles, reflections, wet patches are part of the ground surface** — never separate obj elements, regardless of whether they reflect the hero's silhouette or carry visible content. + +**Failure mode this prevents:** when a standing hero is the focal element and the floor is also emitted as an obj at the bottom of the frame, the renderer treats the floor obj as a 2D frame band rather than a perspectival receding plane, and clips the hero's legs into it — figure rendered half-in-the-ground with feet/calves buried. + +**Discrete objects ON the floor are still elements:** broken glass shards, crushed cans, scattered debris, leaves, rocks, dropped tools, brick fragments, foreground litter remain obj elements. The rule applies to the SURFACE itself and any state of that surface (wet, frozen, muddy, puddled), never to solid objects resting on it. + +### Background is the shell only — no individually-placeable things + +Furniture, vehicles, equipment, people, animals, decor (artwork, signs, plants in pots, stacks of books), free-standing lamps → obj elements, never `background`. + +### Shell-affixed prominent objects → DUAL MENTION + +Some objects are simultaneously part of the shell AND focal elements that define the room's identity: a chalkboard covering the back wall of a classroom, a fireplace built into a living-room wall, a large mounted TV, a stage proscenium, a built-in altar, a built-in bookshelf, a large fixed reception desk, a fixed sign/banner. + +For these, MANDATORY all three steps: +1. **MENTION in `background`** as part of the shell — anchors the object to the wall. +2. **EMIT as an obj element** with the qualifier `"the primary background element"` (or similar) at the start of its desc. The obj carries the detail (material, content, frame, mounting). +3. **PLACE FIRST in the elements list** so painter's-algorithm draws it behind foreground items. 
+ +Skipping step 1 (the most common failure) makes the renderer float the object in mid-room or render it in front of foreground subjects. + +This is an EXCEPTION to the shell rule's "no individually placeable things". Applies ONLY to objects that genuinely define the room's architectural identity. Free-standing items (chairs, table lamps, plants in pots, framed pictures on a wall) get the normal treatment: elements only, no background mention. + +### Recession/arrangement is not architecture + +Do not smuggle furniture or people into `background` by describing them as a receding arrangement. Forbidden background phrasings: `rows of desks recede toward the back`, `a grid of desks fills the room`, `students seated at the desks`, `chairs arranged in front of the podium`, `the room is filled with people`, `cars parked along the street`, `customers seated at the tables`. The arrangement IS the foreground content — emit elements. + +### No medium/post-processing effects in background + +`background` describes WHAT is in the scene, not HOW it was made. Forbidden in `background` — even when the prompt names the effect (route those to HLD instead): +- Film grain, Kodak/Portra/Tri-X grain, ISO noise +- Lens flare, chromatic aberration, vignetting, bokeh quality +- Color cast / film-stock shift (warm shift, cool shift) +- Paper texture, paper grain, canvas texture +- Brushstroke texture, palette-knife texture +- Halftone dots, screen-print texture, risograph texture + +**Test:** read `background` aloud. If you can picture the EMPTY room from the description — no furniture, no people, no equipment, no wall decor — you're in the shell. If anything disappears when you remove the room's contents, the background has leaked. + +## BBOX STRATEGY + +INCLUDE bboxes on elements where precise positioning matters — portrait subjects, products on a surface, logos, signs on a wall, distinct individually-placeable objects. 
+ +OMIT bboxes on elements that represent dense or hard-to-enumerate visuals — crowds, fields of wildflowers, scattered particles, starry skies. Per-element judgment. + +### Coordinate system + +Coordinates are normalized to the target image shape: `x` runs left→right along full width (0 = left edge, 1000 = right), `y` runs top→bottom along full height (0 = top, 1000 = bottom). Top-left origin. Format `[y1, x1, y2, x2]` with `y1 < y2`, `x1 < x2`. + +### Shape warning (common failure) + +Bbox values are normalized to 0–1000 in BOTH axes. A square `[0, 0, 500, 500]` is square only on a square frame; on 16:9 it becomes a wide rectangle, on 9:16 a tall rectangle. Most bbox failures (extra subjects, duplicates, mis-scaled objects) come from this mismatch. + +For round objects or square on-screen regions, scale spans so `(x2-x1)/(y2-y1) ≈ H/W` (on a 16:9 frame, a circle with a y-span of 400 gets an x-span of 225). For single-subject prompts on wide frames, prefer narrower x-spans. For multi-subject prompts, give each a tight bbox so no one bbox dominates and invites a duplicate. + +## SPECIFICITY — commit to one value + +This JSON feeds a diffusion model. Leave nothing for the model to invent or choose. + +**Banned hedge phrasings** (in elements and background): `things like`, `such as`, `e.g.`, `for example`, `or similar`, `various`, `could include`, `might be`, `some kind of`, `style of`. Replace with concrete nouns, counts, colors, materials, poses. + +**Banned alternative listings for one property:** `pale institutional off-white or pale green`, `oak or walnut`, `cream or ivory`, `late afternoon or early evening`, `italic serif or italic sans-serif`, `bold or semibold`. Pick ONE and commit. `or` is reserved for the loader's exclusive-choice idiom (`'YES' or 'NO'`), not captioner hedging. + +**Typography specifically:** name ONE typeface category (serif OR sans-serif OR display OR script OR monospace), ONE weight (bold/regular/light/medium), ONE style (italic OR upright). Never two joined by `or`. 
+ +**Banned "implied/suggested" hedges:** `a desk corner implied`, `a chair suggested beneath the figure`, `a building hinted at`, `a shadow that reads as a person`. If it's in the scene, paint it concretely. If it isn't, leave it out. Forbidden words: `implied, suggested, hinted, barely visible, possibly, perhaps, maybe, might be, could be, reads as, almost`. + +**Exhaustive content preservation.** When the user provides enumerable content — schedules, itineraries, lists, menu items, steps, names, times — every item must appear in the output. Use as many text elements as needed; never sacrifice completeness for layout. + +**Named prompt elements MUST appear.** Every explicitly-named visual unit in the user prompt MUST appear as its own element: +- Input `text:` sections — every entry becomes its own text element, verbatim. Zero tolerance: 3 entries in input → ≥3 text elements in output. Empty `text: []` is the only case where text elements may be omitted on that basis. +- Quoted strings (single or double quotes) — each is its own text element. +- Speech bubbles / dialogue callouts / thought bubbles / captions — each gets a text element for the quoted string AND an obj element for the bubble/balloon/container. +- Named decorative elements (`small medical cross icon top-left`, `airplane arc trajectory`, `flame-lick flourish at the tail`) — each gets its own obj. +- Named badges / chips / CTAs / strips — each gets its own obj (and text if it carries a quoted string). +- Named accents / graphic devices (`hairline rule`, `dot grid`, `accent line`, `divider`) — each gets its own obj UNLESS it's a scene-wide overlay belonging in `background`. + +**Test before emitting:** count named visual units in the user prompt; element list must contain at least that many. 
+ +**No placeholder enumeration.** When the imagined image contains a sequentially-numbered, alphabetically-labeled, or otherwise individually-identified set (stones numbered 1–50, parking spaces A1–A20, place cards `1st`–`12th`, a periodic table of 118 elements, a calendar grid of 31 dates, a 22-name team roster), EACH item is its own element. No `etc.`, no `and so on`, no `6 through 49`, no single obj grouping all into one cluster. List ALL of them. + +The "dense unenumerable group" exception (crowd of thousands, field of wildflowers, starry sky) does NOT apply to enumerable sets — if items are sequentially identified, they're enumerable BY DEFINITION. + +**Don't invent visual concepts the user didn't ask for.** Forbidden without explicit user request: `glitch art`, `wireframe overlay`, `mesh that fragments the body`, `digital artifacts`, `dissolved`, `decompose`. If the prompt asks for a cinematic photo of a journalist, render a cinematic photo of a journalist — not a glitch-art composite. + +## PLANNING — turn the user idea into elements + +### 1. Pick a medium + +`photograph | illustration | 3d_render | painting | graphic_design` — this is the `medium` token (photograph ⇒ `photo`, all others ⇒ `art_style`), and it also frames HLD/background prose naturally. + +Decision: **DESIGNED artifact vs CAPTURED / DRAWN / RENDERED moment.** +- **graphic_design** — poster, book cover, album cover, magazine cover, flyer, banner, social post, sticker, logo, wordmark, packaging, app icon, UI mockup, infographic, menu, greeting card, ticket, signage. If a human designer would sit at a desk to make it. +- **photograph** — portrait, landscape, lifestyle, street, sport, wildlife, food, product, fashion editorial (when described as a photograph). Default for ambiguous everyday scenes. +- **illustration** — cartoon, anime, manga, comic, ink, vector, pixel art, children's book illustration, named studios (Ghibli, KyoAni, Pixar 2D). 
+- **painting** — watercolor, oil, gouache, acrylic, traditional painterly work. +- **3d_render** — CGI, octane/unreal/blender, hyperrealistic product render, arch viz, isometric low-poly, voxel, named 3D studios. + +Silent / ambiguous → photograph (default). The subject's reality status does NOT override this default — wizards, dragons, aliens, robots in a photograph are valid; the brief must explicitly ASK for illustration / painting / render to get one. + +Imperative verbs at the start ("Illustrate a…", "Paint a…", "Draw a…", "Render a…") are NOT medium signals — they mean "depict / show". Default to photograph unless an explicit medium-noun or style name appears. + +### 2. Style commitment + +Inside HLD/background prose, name the style ONCE (`Studio Ghibli animation`, `Pixar 3D animation`, `35mm film photograph`, `iPhone photo`, `editorial digital painting`, `flat vector illustration`). Keep it short — recognizable style names are enough; the renderer knows them. Don't append technique detail (`with hand-painted gouache backgrounds`) on top of well-known names. + +**"Professional picture/photo/portrait" of a person means PROFESSIONAL CONTEXT, not professional camera equipment.** Read as corporate headshot, LinkedIn profile, business bio — neutral business attire, soft even daylight, neutral backdrop, friendly approachable expression. NOT dramatic studio rim-lighting, creamy DSLR bokeh, dark moody backdrop. + +### 3. Photoreal defaults — AVOID "warm" + +For photographic prompts (no specified medium beyond `photo`/`photorealistic`/`selfie`/real-world scene): +- Default to iPhone aesthetic — phone snapshot, ambient natural light, neutral white balance, accurate (not flattering) skin tones, ordinary framing. AVOID DSLR-magazine markers (creamy bokeh, telephoto compression, dramatic rim lighting, cinematic grade) — those signal AI-generation. +- Default lighting framing: `natural daylight`, `overcast daylight`, `diffused daylight`, `cool-neutral white balance`. 
The word **"warm"** (in any phrase: `warm light`, `warm window light`, `warm tone`, `warm grading`) is BANNED as a grading adjective — it triggers the amber/golden AI look that ruins photorealism. When a scene physically has a warm-coloured light source (candle, sodium streetlamp, sunset), describe the SOURCE concretely (`candle flame`, `sodium streetlamp`) and the colour of the LIGHT POOL (`amber pool from the candle`) — but the global grade stays neutral. +- Default composition: prefer non-centered framing (off-center, rule-of-thirds, asymmetrical, leading lines) for portraits, products, single-subject scenes. Use centered framing ONLY when the prompt explicitly calls for it (`centered`, `symmetrical`, `mandala`, `kaleidoscope`) or when the genre is inherently symmetric. +- No motion blur in candid/realistic/iPhone-aesthetic photos. Motion blur is a craft signature (long-exposure pans, light streaks); using it in a candid signals AI. Real phone snapshots freeze the moment. +- Saturation: don't stack `vibrant + bright + intense + saturated + electric + neon` for a neutral subject. Mention saturation ONCE (in HLD or background) only when the prompt explicitly asks. + +### 4. Populate underspecified scenes + +When the brief is sparse, don't render only what's explicitly named. Real scenes are populated. Add believable secondary subjects, micro-props that imply the subject's life, environmental texture, small narrative moments. Each invented element should belong in the world the brief implies — a paddy-field food stall plausibly has a chicken, a sauce bowl, a hand-painted price sign, a lantern. + +**Populate by depth layer.** Foreground (often-skipped), midground, background — each gets its own content. A foreground crop (an out-of-focus leaf at the bottom corner, the rim of a bowl, a fly mid-air close to camera) separates a real photograph from a postcard. 
+ +**Commit to a specific cultural / regional identity.** "Southeast Asian village" is a hedge that produces generic AI visuals. "Vietnamese pho stall by the rice paddies outside Hoi An" is a real place. Specific commitment shapes architecture, signage script, food, dress, props. + +**Built environments need text everywhere.** Real shops, stalls, restaurants, vehicles, signage carry text on practically every surface. Generate text generously: shop name sign, sub-signs (`OPEN` / `TODAY'S SPECIAL`), menu board with handwritten items, price labels, jar/bottle labels, name tags, posters, fortune slips, vehicle/equipment labels, sponsor logos. `text: []` is almost always wrong for built environments — if your scene has a shop/stall/restaurant/workshop/market/vehicle, populate text. Specific content, never `various labels` or `menu items`. + +**Override:** when the brief explicitly says `minimal`, `sparse`, `empty`, `lonely`, `isolated`, `quiet`, `still`, `negative space`, `alone`, `single subject`, `in the middle of nowhere`, respect the restraint and skip populate. + +**Fantastical / sci-fi / fantasy / futuristic briefs get a populate bonus.** Stack sky drama (galaxies, ringed planets, multiple moons, nebulae), opposing focal points (volcano right / waterfall left), mid-distance scale anchors (crystal columns, futuristic cityscape, megastructures), light/energy effects throughout, exotic architecture/geology, deeply saturated palettes. + +## TEXT HANDLING + +For each text element: +- `text` — literal characters appearing in the image, verbatim. Preserve diacritics, capitalization, punctuation. Never transliterate or strip. +- `bbox` — optional, same coordinate system as obj elements. +- `desc` — free-form prose covering size, location, font style, color, orientation, visual effects. + +**Sources of text to include:** +1. **User-quoted text** (single OR double quotes) — verbatim, exact characters. +2. 
**Format-required text** — headlines, taglines, author names, dates, venues, CTA copy, brand names, publisher marks, edition numbers (when format implies them). +3. **In-scene contextual text** — signage, labels, license plates, badges, jersey numbers, t-shirt prints, awnings, neon signs, name tags. +4. **Numeric content** — race numbers, jersey numbers, dates, prices, scores, time displays, address numbers. Numbers ARE text. +5. **Prominent product brand text** — if an element names a prominent product (bottle, cosmetic, package, beverage) and the user didn't supply a real brand, invent a complete brand identity and list every label as text elements. + +**Rules:** +- Exhaustive: if a viewer could read it, it goes in the list. +- Each text element appears ONCE in the list. Do NOT also describe its characters in `description` — refer by role/position instead. +- Use `\n` for line breaks WITHIN a single text element (multi-line sign, stacked headline). Use SEPARATE list items for visually distinct text blocks. +- For stylized hero typography where each letter is a distinct visual unit, stack with `\n` at natural word breaks — long single-line stylized titles produce typos and dropped letters. e.g., `"ENTRE\nVERSOS E\nCONTOS"` not `"ENTRE VERSOS E CONTOS"`. +- **Language scoping:** `scene`/`elements`/`description`/position descriptors are always in ENGLISH regardless of the user's brief language. Only the literal `text` field characters follow the user's brief language. Portuguese brief → English prose + Portuguese `text:` content. + +## POP CULTURE, BRANDS, NAMED REFERENCES + +When the user idea names or clearly implies a brand, trademark, product (sneaker/car/device), public figure, athlete, musician, actor, fictional character, film, show, game, franchise, team — the output MUST carry an explicit named reference in the relevant element `desc`, not a generic stand-in describing the look. 
+ +Don't replace `Nike Dunk Low Panda` with `black and white retro sneakers`, `Spider-Man` with `a red-and-blue masked superhero`, `The Beatles` with `four men in matching suits` — unless the user asked for an anonymous lookalike. Name the specific thing the user pointed at. + +## TRANSPARENT BACKGROUND + +If the user's idea calls for transparent background, transparent canvas, alpha channel, cutout/isolated subject, sticker-style with no backdrop, or similar, the `background` field MUST be exactly this string, verbatim and nothing else: `transparent background` + +Do not paraphrase (no `clear backdrop`, `empty alpha`, `no background`, `PNG transparency`). + +In `high_level_description`, include the literal phrase `on a transparent background`. + +[USER] +TARGET IMAGE ASPECT RATIO: {{aspect_ratio}} (width:height). +User idea: {{original_prompt}} +""" diff --git a/ai-toolkit/extensions_built_in/captioner/prompts/ideogram4_upsample_prompt.py b/ai-toolkit/extensions_built_in/captioner/prompts/ideogram4_upsample_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..b922d8374770f50fc39fe4f53b208741597735ca --- /dev/null +++ b/ai-toolkit/extensions_built_in/captioner/prompts/ideogram4_upsample_prompt.py @@ -0,0 +1,100 @@ +ideogram4_upsample_prompt = """ +[META] +frozen: false +description: Faithful upsampler — lays a user prompt into the structured JSON caption without inventing or embellishing. Preserves triggers/names/styles exactly. Thinking off. +thinking_mode: disabled + +[SYSTEM] +You convert a user prompt into a structured JSON caption an image renderer can consume. You receive the user prompt plus a target aspect ratio, and you emit ONE JSON object. Your job is to LAY OUT what the user described into the required structure — concrete background, elements, bounding boxes, and text. You do NOT invent, expand, populate, or embellish beyond what the structure requires. 
+ +## FIDELITY — read first, applies above everything else + +- **Preserve triggers/tokens EXACTLY.** Any trigger word, unique token, or identifier in the prompt — `[trigger]`, `sks`, `ohwx man`, a code name, a brand token, a person's name — must appear in the output VERBATIM: same characters, case, and brackets. Never paraphrase, translate, pluralize, split, correct, or drop it. Put it in the `desc` (and `high_level_description`) of the element it refers to. +- **Named person → no invented appearance.** If the prompt refers to a person by a name or trigger, do NOT describe or imagine their appearance — no face, hair, skin tone, age, body, or clothing unless the user explicitly stated it. Refer to them by the exact name/trigger and state ONLY what the prompt gives (action, pose, placement). Their identity is carried by the name alone. +- **Named style → no invented style detail.** If a style, medium, artist, or look is named (or carried by a trigger), reference it exactly as given and do NOT describe or elaborate its characteristics. +{{mode_directive}} + +## OUTPUT CONTRACT — exactly three top-level keys, in this order: + +```json +{"high_level_description":"...","style_description":{ ...see STYLE DESCRIPTION... },"compositional_deconstruction":{"background":"...","elements":[ ... ]}} +``` + +- Emit a SINGLE-LINE MINIFIED JSON object — no markdown fences, no commentary, no other top-level keys. +- Preserve non-ASCII characters as-is (CJK, Cyrillic, Arabic, accented Latin). Never escape them as unicode code-point sequences or transliterate. +- Use SINGLE quotes for embedded text references in prose fields (`'Joe's Diner'`). The `text` field is the exception — it holds verbatim characters. + +### Target aspect ratio (input only — never emit it) + +The user message gives a target aspect ratio as `W:H` (or `auto`). Use it ONLY to size your bounding boxes correctly (a box is square only on a square frame). 
Do NOT emit an `aspect_ratio` key — it is not part of the output. + +### `high_level_description` (50-word cap) + +One short sentence, reads like a natural prompt, starts with the subject — no "this image shows". Names the subject(s), any trigger/name verbatim, and the overall composition. Don't enumerate fine detail. + +## STYLE DESCRIPTION — the `style_description` block (always required) + +A nested object, filled FROM the prompt. It carries EXACTLY ONE render key — `photo` for photographs, `art_style` for everything else — NEVER both. Key order is strict and branch-dependent: + +- **Photograph** → `aesthetics`, `lighting`, `photo`, `medium`, `color_palette` +- **Non-photo** (illustration / 3D / painting / graphic design) → `aesthetics`, `lighting`, `medium`, `art_style`, `color_palette` + +Fields: +- `aesthetics` — the overall mood/aesthetic in a short phrase. +- `lighting` — the lighting (direction, quality, colour). Describe a warm-coloured source concretely; never use the bare word `warm` as a grade. +- `photo` (photographs ONLY) — the camera/film capture spec (framing, grain, focus). +- `art_style` (non-photo ONLY) — the rendering technique (`flat vector, clean edges`; `octane 3D render`; `loose watercolor`). +- `medium` — exactly one token: `photograph` / `illustration` / `3d_render` / `painting` / `graphic_design`. Photograph ⇒ use `photo`; any other ⇒ use `art_style`. +- `color_palette` — an array of dominant colours as UPPERCASE `#RRGGBB` strings (`"#1B3A5C"`), up to 16, ordered most → least dominant. ALWAYS the last key. + +Respect FIDELITY: if the prompt NAMES a style, medium, artist, or look, put it in these fields BY NAME (e.g. `medium`/`art_style`/`aesthetics`) and do NOT invent its characteristics. Pull lighting and colours from what the prompt states. 
In faithful mode, only commit to a value the prompt implies, keeping the rest minimal; in creative mode you may infer fitting style values — but never elaborate a named style and never override what the user gave. + +## ELEMENTS + +Each element is one of (keys in EXACTLY this order): +``` +{"type":"obj","bbox":[y1,x1,y2,x2],"desc":"..."} +{"type":"text","bbox":[y1,x1,y2,x2],"text":"LINE ONE\nLINE TWO","desc":"..."} +``` +`bbox` is OPTIONAL per element (see BBOX). Do NOT emit a per-element `color_palette` — an element's colours belong in its `desc` as prose; the only colour-conditioning field is the top-level `style_description.color_palette`. + +- **One coherent subject = ONE element.** A person, animal, vehicle, building, or plant is a single element; its parts are attributes of that element's `desc`, never separate elements. Multiple distinct subjects = multiple elements (one each). +- **`desc`:** identity first, then only the attributes the user gave (or that the structure plainly needs). For a named person/trigger: name + action/pose/placement ONLY, no appearance. For a generic un-named subject, you may state the concrete attributes the prompt implies, but do not invent an identity or backstory. + +## BACKGROUND — the scene shell only + +`background` describes the shell: walls/finishes, floor/ground, sky, ambient light, and distant out-of-focus context. + +- The floor/ground/turf/pavement, sky, horizon, and distant crowds live in `background` ONLY — never as obj elements. (A floor emitted as an obj clips standing subjects' legs.) +- **No double-counting:** anything named in `background` must NOT also be an obj element. +- Don't smuggle furniture or people into `background` as a "receding arrangement" — those are foreground elements. +- If the prompt asks for a transparent/cutout background, set `background` to exactly: `transparent background` (and include `on a transparent background` in the HLD). 
+ +## BBOX + +Coordinates are normalized to 0–1000 in BOTH axes, top-left origin. Format `[y1, x1, y2, x2]` with `y1 < y2`, `x1 < x2`. + +A box is square only on a square frame; on a wide or tall frame the same numbers stretch. For round or square on-screen subjects, scale the spans so `(x2-x1)/(y2-y1) ≈ W/H`. Include bboxes where position matters; omit them for dense/uncountable fills (crowds, starfields). + +## TEXT + +- Every quoted string in the prompt becomes its own `text` element, with `text` = the verbatim characters (preserve case, punctuation, diacritics, and any trigger). Use `\n` for line breaks within one text block; separate blocks get separate elements. +- Include clearly in-scene text (a sign, a label) only when the user asked for it — do not invent signage or brand copy. +- Prose fields (`desc`, `background`, `high_level_description`) are always in ENGLISH; only the `text` field follows the prompt's language. + +## SPECIFICITY + +- For details the user GAVE, commit to one concrete value — no hedging (`things like`, `such as`, `various`), no alternatives (`oak or walnut`). +- For details the user did NOT give, add a single concrete value only when the structure requires it (e.g. a plain background shell); otherwise leave it out. +- Never hedge, never invent appearance for a named person, and never invent characteristics for a named style. + +## ADDITIONAL INSTRUCTIONS + +Honor the following extra instructions from the user. They must NEVER override the OUTPUT CONTRACT, the FIDELITY rules, or the structure above. + +{{user_instructions}} + +[USER] +TARGET IMAGE ASPECT RATIO: {{aspect_ratio}} (width:height). 
+User prompt: {{original_prompt}} +""" diff --git a/ai-toolkit/extensions_built_in/concept_replacer/ConceptReplacer.py b/ai-toolkit/extensions_built_in/concept_replacer/ConceptReplacer.py new file mode 100644 index 0000000000000000000000000000000000000000..1600e8e1851c402f5468d6c48fdc41ec2d4487fb --- /dev/null +++ b/ai-toolkit/extensions_built_in/concept_replacer/ConceptReplacer.py @@ -0,0 +1,151 @@ +import random +from collections import OrderedDict +from torch.utils.data import DataLoader +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +import torch +from jobs.process import BaseSDTrainProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class ConceptReplacementConfig: + def __init__(self, **kwargs): + self.concept: str = kwargs.get('concept', '') + self.replacement: str = kwargs.get('replacement', '') + + +class ConceptReplacer(BaseSDTrainProcess): + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + replacement_list = self.config.get('replacements', []) + self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list] + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + + # textual inversion: an embedding is being trained + if self.embedding is not None: + # set text encoder to train.
Not sure if this is necessary but diffusers example did it + self.sd.text_encoder.train() + + def hook_train_loop(self, batch): + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + network_weight_list = batch.get_network_weight_list() + + # have a blank network so we can wrap it in a context and set multipliers without checking every time + if self.network is not None: + network = self.network + else: + network = BlankNetwork() + + batch_replacement_list = [] + # get a random replacement for each prompt + for prompt in conditioned_prompts: + replacement = random.choice(self.replacement_list) + batch_replacement_list.append(replacement) + + # build out prompts + concept_prompts = [] + replacement_prompts = [] + for idx, replacement in enumerate(batch_replacement_list): + prompt = conditioned_prompts[idx] + + # insert shuffled concept at beginning and end of prompt + shuffled_concept = [x.strip() for x in replacement.concept.split(',')] + random.shuffle(shuffled_concept) + shuffled_concept = ', '.join(shuffled_concept) + concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}") + + # insert replacement at beginning and end of prompt + shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')] + random.shuffle(shuffled_replacement) + shuffled_replacement = ', '.join(shuffled_replacement) + replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}") + + # predict noise for the replacement prompts without network; this prediction is the training target + conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype) + + replacement_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + ) + + del conditional_embeds + replacement_pred =
replacement_pred.detach() + + self.optimizer.zero_grad() + flush() + + # decide whether the text encoder needs gradients + grad_on_text_encoder = False + if self.train_config.train_text_encoder: + grad_on_text_encoder = True + + if self.embedding: + grad_on_text_encoder = True + + # set the network multipliers from the batch's weight list + network.multiplier = network_weight_list + + # activate network if it exists + with network: + with torch.set_grad_enabled(grad_on_text_encoder): + # embed the prompts + conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype) + if not grad_on_text_encoder: + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + self.optimizer.zero_grad() + flush() + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + ) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # back propagate loss to free ram + loss.backward() + flush() + + # apply gradients + self.optimizer.step() + self.optimizer.zero_grad() + self.lr_scheduler.step() + + if self.embedding is not None: + # Let's make sure we don't update any embedding weights besides the newly added token + self.embedding.restore_embeddings() + + loss_dict = OrderedDict( + {'loss': loss.item()} + ) + # reset network multiplier + network.multiplier = 1.0 + + return loss_dict diff --git a/ai-toolkit/extensions_built_in/concept_replacer/__init__.py b/ai-toolkit/extensions_built_in/concept_replacer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69dc731141158741d7b5165a7e5dc2f77b467051 ---
/dev/null +++ b/ai-toolkit/extensions_built_in/concept_replacer/__init__.py @@ -0,0 +1,26 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class ConceptReplacerExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "concept_replacer" + + # name is the name of the extension for printing + name = "Concept Replacer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ConceptReplacer import ConceptReplacer + return ConceptReplacer + + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ConceptReplacerExtension, +] diff --git a/ai-toolkit/extensions_built_in/concept_replacer/config/train.example.yaml b/ai-toolkit/extensions_built_in/concept_replacer/config/train.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..793d5d55b9a282d58b53f0e6fafba0b5aa66f7af --- /dev/null +++ b/ai-toolkit/extensions_built_in/concept_replacer/config/train.example.yaml @@ -0,0 +1,91 @@ +--- +job: extension +config: + name: test_v1 + process: + - type: 'textual_inversion_trainer' + training_folder: "out/TI" + device: cuda:0 + # for tensorboard logging + log_dir: "out/.tensorboard" + embedding: + trigger: "your_trigger_here" + tokens: 12 + init_words: "man with short brown hair" + save_format: "safetensors" # 'safetensors' or 'pt' + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 5 # only affects step counts + datasets: + - folder_path: "/path/to/dataset" + caption_ext: "txt" + default_caption: "[trigger]" + buckets: true + resolution: 512 + train: + noise_scheduler: "ddpm" # or 
"ddpm", "lms", "euler_a" + steps: 3000 + weight_jitter: 0.0 + lr: 5e-5 + train_unet: false + gradient_checkpointing: true + train_text_encoder: false + optimizer: "adamw" +# optimizer: "prodigy" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 4 + dtype: bf16 + xformers: true + min_snr_gamma: 5.0 +# skip_first_sample: true + noise_offset: 0.0 # not needed for this + model: + # objective reality v2 + name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of [trigger] laughing" + - "photo of [trigger] smiling" + - "[trigger] close up" + - "dark scene [trigger] frozen" + - "[trigger] nighttime" + - "a painting of [trigger]" + - "a drawing of [trigger]" + - "a cartoon of [trigger]" + - "[trigger] pixar style" + - "[trigger] costume" + neg: "" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. 
The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/ai-toolkit/extensions_built_in/concept_slider/ConceptSliderTrainer.py b/ai-toolkit/extensions_built_in/concept_slider/ConceptSliderTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6c9fce74e12b826e1920fec2b7fb9bed13f49c31 --- /dev/null +++ b/ai-toolkit/extensions_built_in/concept_slider/ConceptSliderTrainer.py @@ -0,0 +1,302 @@ +from collections import OrderedDict +from typing import Optional + +import torch + +from extensions_built_in.sd_trainer.DiffusionTrainer import DiffusionTrainer +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds +from toolkit.train_tools import get_torch_dtype + + +class ConceptSliderTrainerConfig: + def __init__(self, **kwargs): + self.guidance_strength: float = kwargs.get("guidance_strength", 3.0) + self.anchor_strength: float = kwargs.get("anchor_strength", 1.0) + self.positive_prompt: str = kwargs.get("positive_prompt", "") + self.negative_prompt: str = kwargs.get("negative_prompt", "") + self.target_class: str = kwargs.get("target_class", "") + self.anchor_class: Optional[str] = kwargs.get("anchor_class", None) + + +def norm_like_tensor(tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Rescale the tensor so its mean and std, taken over all elements, match the target tensor's; 1e-8 is added to the tensor's std to avoid dividing by zero.""" + tensor_mean = tensor.mean() + tensor_std = tensor.std() + target_mean = target.mean() + target_std = target.std() + normalized_tensor = (tensor - tensor_mean) / ( + tensor_std + 1e-8 + ) * target_std + target_mean + return normalized_tensor + + +class ConceptSliderTrainer(DiffusionTrainer): + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): +
super().__init__(process_id, job, config, **kwargs) + self.do_guided_loss = True + + self.slider: ConceptSliderTrainerConfig = ConceptSliderTrainerConfig( + **self.config.get("slider", {}) + ) + + self.positive_prompt = self.slider.positive_prompt + self.positive_prompt_embeds: Optional[PromptEmbeds] = None + self.negative_prompt = self.slider.negative_prompt + self.negative_prompt_embeds: Optional[PromptEmbeds] = None + self.target_class = self.slider.target_class + self.target_class_embeds: Optional[PromptEmbeds] = None + self.anchor_class = self.slider.anchor_class + self.anchor_class_embeds: Optional[PromptEmbeds] = None + + def hook_before_train_loop(self): + # do this before calling parent as it unloads the text encoder if requested + if self.is_caching_text_embeddings: + # make sure model is on cpu for this part so we don't oom. + self.sd.unet.to("cpu") + + # cache the slider prompt embeds (positive, target class, negative, and the anchor class when one is set) + with torch.no_grad(): + self.positive_prompt_embeds = ( + self.sd.encode_prompt( + [self.positive_prompt], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + self.target_class_embeds = ( + self.sd.encode_prompt( + [self.target_class], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + self.negative_prompt_embeds = ( + self.sd.encode_prompt( + [self.negative_prompt], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + if self.anchor_class is not None: + self.anchor_class_embeds = ( + self.sd.encode_prompt( + [self.anchor_class], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + # call parent + super().hook_before_train_loop() + + def get_guided_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: "DataLoaderBatchDTO", + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs, + ): + # todo
for embeddings, we need to run without trigger words + was_unet_training = self.sd.unet.training + was_network_active = False + if self.network is not None: + was_network_active = self.network.is_active + self.network.is_active = False + + # do out prior preds first + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + self.sd.unet.eval() + noisy_latents = noisy_latents.to(self.device_torch, dtype=dtype).detach() + + batch_size = noisy_latents.shape[0] + + positive_embeds = concat_prompt_embeds( + [self.positive_prompt_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + target_class_embeds = concat_prompt_embeds( + [self.target_class_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + negative_embeds = concat_prompt_embeds( + [self.negative_prompt_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + + if self.anchor_class_embeds is not None: + anchor_embeds = concat_prompt_embeds( + [self.anchor_class_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + + if self.anchor_class_embeds is not None: + # if we have an anchor, do it + combo_embeds = concat_prompt_embeds( + [ + positive_embeds, + target_class_embeds, + negative_embeds, + anchor_embeds, + ] + ) + num_embeds = 4 + else: + combo_embeds = concat_prompt_embeds( + [positive_embeds, target_class_embeds, negative_embeds] + ) + num_embeds = 3 + + # do them in one batch, VRAM should handle it since we are no grad + combo_pred = self.sd.predict_noise( + latents=torch.cat([noisy_latents] * num_embeds, dim=0), + conditional_embeddings=combo_embeds, + timestep=torch.cat([timesteps] * num_embeds, dim=0), + guidance_scale=1.0, + guidance_embedding_scale=1.0, + batch=batch, + ) + + if self.anchor_class_embeds is not None: + positive_pred, neutral_pred, negative_pred, anchor_target = ( + combo_pred.chunk(4, dim=0) + ) + else: + anchor_target = None + positive_pred, neutral_pred, negative_pred = combo_pred.chunk(3, dim=0) + + # calculate the targets + 
guidance_scale = self.slider.guidance_strength + + # enhance_positive_target = neutral_pred + guidance_scale * ( + # positive_pred - negative_pred + # ) + # enhance_negative_target = neutral_pred + guidance_scale * ( + # negative_pred - positive_pred + # ) + # erase_negative_target = neutral_pred - guidance_scale * ( + # negative_pred - positive_pred + # ) + # erase_positive_target = neutral_pred - guidance_scale * ( + # positive_pred - negative_pred + # ) + + positive = (positive_pred - neutral_pred) - (negative_pred - neutral_pred) + negative = (negative_pred - neutral_pred) - (positive_pred - neutral_pred) + + enhance_positive_target = neutral_pred + guidance_scale * positive + enhance_negative_target = neutral_pred + guidance_scale * negative + erase_negative_target = neutral_pred - guidance_scale * negative + erase_positive_target = neutral_pred - guidance_scale * positive + + # normalize to neutral std/mean + enhance_positive_target = norm_like_tensor( + enhance_positive_target, neutral_pred + ) + enhance_negative_target = norm_like_tensor( + enhance_negative_target, neutral_pred + ) + erase_negative_target = norm_like_tensor( + erase_negative_target, neutral_pred + ) + erase_positive_target = norm_like_tensor( + erase_positive_target, neutral_pred + ) + + if was_unet_training: + self.sd.unet.train() + + # restore network + if self.network is not None: + self.network.is_active = was_network_active + + if self.anchor_class_embeds is not None: + # do a grad inference with our target prompt + embeds = concat_prompt_embeds([target_class_embeds, anchor_embeds]).to( + self.device_torch, dtype=dtype + ) + + noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=0).to( + self.device_torch, dtype=dtype + ) + timesteps = torch.cat([timesteps, timesteps], dim=0) + else: + embeds = target_class_embeds.to(self.device_torch, dtype=dtype) + + # do positive first + self.network.set_multiplier(1.0) + pred = self.sd.predict_noise( + latents=noisy_latents, + 
conditional_embeddings=embeds, + timestep=timesteps, + guidance_scale=1.0, + guidance_embedding_scale=1.0, + batch=batch, + ) + + if self.anchor_class_embeds is not None: + class_pred, anchor_pred = pred.chunk(2, dim=0) + else: + class_pred = pred + anchor_pred = None + + # enhance positive loss + enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_positive_target) + + erase_loss = torch.nn.functional.mse_loss(class_pred, erase_negative_target) + + if anchor_target is None: + anchor_loss = torch.zeros_like(erase_loss) + else: + anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target) + + anchor_loss = anchor_loss * self.slider.anchor_strength + + # send backward now because gradient checkpointing needs network polarity intact + total_pos_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0 + total_pos_loss.backward() + total_pos_loss = total_pos_loss.detach() + + # now do negative + self.network.set_multiplier(-1.0) + pred = self.sd.predict_noise( + latents=noisy_latents, + conditional_embeddings=embeds, + timestep=timesteps, + guidance_scale=1.0, + guidance_embedding_scale=1.0, + batch=batch, + ) + + if self.anchor_class_embeds is not None: + class_pred, anchor_pred = pred.chunk(2, dim=0) + else: + class_pred = pred + anchor_pred = None + + # enhance negative loss + enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_negative_target) + erase_loss = torch.nn.functional.mse_loss(class_pred, erase_positive_target) + + if anchor_target is None: + anchor_loss = torch.zeros_like(erase_loss) + else: + anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target) + anchor_loss = anchor_loss * self.slider.anchor_strength + total_neg_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0 + total_neg_loss.backward() + total_neg_loss = total_neg_loss.detach() + + self.network.set_multiplier(1.0) + + total_loss = (total_pos_loss + total_neg_loss) / 2.0 + + # add a grad so backward works right + total_loss.requires_grad_(True) + 
return total_loss diff --git a/ai-toolkit/extensions_built_in/concept_slider/__init__.py b/ai-toolkit/extensions_built_in/concept_slider/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a624b1850bb3b3071ad6520765e50e701bf00cb --- /dev/null +++ b/ai-toolkit/extensions_built_in/concept_slider/__init__.py @@ -0,0 +1,26 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class ConceptSliderTrainerTrainer(Extension): + # uid must be unique, it is how the extension is identified + uid = "concept_slider" + + # name is the name of the extension for printing + name = "Concept Slider Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ConceptSliderTrainer import ConceptSliderTrainer + + return ConceptSliderTrainer + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ConceptSliderTrainerTrainer +] diff --git a/ai-toolkit/extensions_built_in/dataset_tools/DatasetTools.py b/ai-toolkit/extensions_built_in/dataset_tools/DatasetTools.py new file mode 100644 index 0000000000000000000000000000000000000000..d969b77aa9fa9e0fcce001e2fb5d102d60058f6a --- /dev/null +++ b/ai-toolkit/extensions_built_in/dataset_tools/DatasetTools.py @@ -0,0 +1,20 @@ +from collections import OrderedDict +import gc +import torch +from jobs.process import BaseExtensionProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class DatasetTools(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + + def run(self): + super().run() + + raise NotImplementedError("This extension is not yet 
implemented") diff --git a/ai-toolkit/extensions_built_in/dataset_tools/SuperTagger.py b/ai-toolkit/extensions_built_in/dataset_tools/SuperTagger.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb3c70e57ff3d2cc6b5dd23400d9a30b72243c7 --- /dev/null +++ b/ai-toolkit/extensions_built_in/dataset_tools/SuperTagger.py @@ -0,0 +1,196 @@ +import copy +import json +import os +from collections import OrderedDict +import gc +import traceback +import torch +from PIL import Image, ImageOps +from tqdm import tqdm + +from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo +from .tools.fuyu_utils import FuyuImageProcessor +from .tools.image_tools import load_image, ImageProcessor, resize_to_max +from .tools.llava_utils import LLaVAImageProcessor +from .tools.caption import default_long_prompt, default_short_prompt, default_replacements +from jobs.process import BaseExtensionProcess +from .tools.sync_tools import get_img_paths + +img_ext = ['.jpg', '.jpeg', '.png', '.webp'] + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +VERSION = 2 + + +class SuperTagger(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + parent_dir = config.get('parent_dir', None) + self.dataset_paths: list[str] = config.get('dataset_paths', []) + self.device = config.get('device', 'cuda') + self.steps: list[Step] = config.get('steps', []) + self.caption_method = config.get('caption_method', 'llava:default') + self.caption_prompt = config.get('caption_prompt', default_long_prompt) + self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt) + self.force_reprocess_img = config.get('force_reprocess_img', False) + self.caption_replacements = config.get('caption_replacements', default_replacements) + self.caption_short_replacements = config.get('caption_short_replacements', default_replacements) + self.master_dataset_dict = OrderedDict() + 
def get_image_processor(self):
    """Create the captioning backend selected by ``self.caption_method``.

    The method string is matched by prefix ('llava' first, then 'fuyu') and the
    chosen processor is constructed for ``self.device``.

    Raises:
        ValueError: if the caption method matches neither prefix.
    """
    # (prefix, factory) pairs checked in order; the lambdas defer class lookup
    # until a prefix actually matches
    backends = (
        ('llava', lambda: LLaVAImageProcessor(device=self.device)),
        ('fuyu', lambda: FuyuImageProcessor(device=self.device)),
    )
    for prefix, make_processor in backends:
        if self.caption_method.startswith(prefix):
            return make_processor()

    raise ValueError(f"Unknown caption method: {self.caption_method}")
+ did_update_image = True + image = load_image(img_path) + if img_info.force_image_process: + did_update_image = True + image = load_image(img_path) + + # go through the needed steps + for step in copy.deepcopy(img_info.state.steps_to_complete): + if step == 'caption': + # load image + if image is None: + image = load_image(img_path) + if caption_image is None: + caption_image = resize_to_max(image, 1024, 1024) + + if not self.image_processor.is_loaded: + print('Loading Model. Takes a while, especially the first time') + self.image_processor.load_model() + + img_info.caption = self.image_processor.generate_caption( + image=caption_image, + prompt=self.caption_prompt, + replacements=self.caption_replacements + ) + img_info.mark_step_complete(step) + elif step == 'caption_short': + # load image + if image is None: + image = load_image(img_path) + + if caption_image is None: + caption_image = resize_to_max(image, 1024, 1024) + + if not self.image_processor.is_loaded: + print('Loading Model. 
def run(self):
    """Process every raw image of every configured dataset.

    Images are collected from each dataset's RAW_DIR folder and passed to
    ``process_image`` one by one; a failure on one image is printed and does
    not stop the run. If a master config file is configured, the accumulated
    per-image info is written to it as JSON. Finally the captioning model is
    dropped and memory is flushed.
    """
    super().run()

    # gather the raw images of all datasets, in dataset order
    pending = [
        raw_image_path
        for dataset_path in self.dataset_paths
        for raw_image_path in get_img_paths(os.path.join(dataset_path, RAW_DIR))
    ]

    if pending:
        print(f"Found {len(pending)} to process")
    else:
        print("No images to process")

    for img_path in tqdm(pending, desc="Processing images"):
        try:
            self.process_image(img_path)
        except Exception:
            # best effort: show the full stack trace and move on to the next image
            print(traceback.format_exc())

    master_file = self.dataset_master_config_file
    if master_file is not None:
        # persist the combined image info as json
        with open(master_file, 'w') as f:
            json.dump(self.master_dataset_dict, f, indent=4)

    del self.image_processor
    flush()
class SyncFromCollection(BaseExtensionProcess):
    """Sync remote photo collections (unsplash / pexels) into local dataset folders.

    New photos are first downloaded into each dataset's NEW_DIR folder and are
    moved into RAW_DIR once all datasets have been synced.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        """Read the process-wide minimum size and build one config per dataset.

        Args:
            process_id: forwarded to the base process.
            job: forwarded to the base process.
            config: process config; ``dataset_sync`` holds the per-dataset
                settings, ``min_width`` / ``min_height`` are the fallbacks.
        """
        super().__init__(process_id, job, config)

        self.min_width = config.get('min_width', 1024)
        self.min_height = config.get('min_height', 1024)

        # add our min_width and min_height to each dataset config if they don't exist
        for dataset_config in config.get('dataset_sync', []):
            if 'min_width' not in dataset_config:
                dataset_config['min_width'] = self.min_width
            if 'min_height' not in dataset_config:
                dataset_config['min_height'] = self.min_height

        self.dataset_configs: List[DatasetSyncCollectionConfig] = [
            DatasetSyncCollectionConfig(**dataset_config)
            for dataset_config in config.get('dataset_sync', [])
        ]
        print(f"Found {len(self.dataset_configs)} dataset configs")

    def move_new_images(self, root_dir: str):
        """Move freshly downloaded images from NEW_DIR into RAW_DIR and drop NEW_DIR.

        Does nothing when the temporary folder does not exist (nothing was
        downloaded, or the sync failed before a download happened).
        """
        raw_dir = os.path.join(root_dir, RAW_DIR)
        new_dir = os.path.join(root_dir, NEW_DIR)

        if not os.path.isdir(new_dir):
            # nothing was staged for this dataset
            return

        new_images = get_img_paths(new_dir)
        if new_images:
            # the raw folder may not exist yet on a first sync
            os.makedirs(raw_dir, exist_ok=True)

        for img_path in new_images:
            # move to raw
            new_path = os.path.join(raw_dir, os.path.basename(img_path))
            shutil.move(img_path, new_path)

        # remove new dir
        shutil.rmtree(new_dir)

    def sync_dataset(self, config: DatasetSyncCollectionConfig):
        """Download every photo of one collection that is not already on disk.

        Returns:
            dict with the counters ``num_downloaded``, ``num_skipped``, ``bad``
            and ``total`` (``total`` counts downloaded + skipped photos only).

        Raises:
            ValueError: if ``config.host`` is not a known host.
        """
        if config.host == 'unsplash':
            get_images = get_unsplash_images
        elif config.host == 'pexels':
            get_images = get_pexels_images
        else:
            raise ValueError(f"Unknown host: {config.host}")

        results = {
            'num_downloaded': 0,
            'num_skipped': 0,
            'bad': 0,
            'total': 0,
        }

        photos = get_images(config)
        raw_dir = os.path.join(config.directory, RAW_DIR)
        new_dir = os.path.join(config.directory, NEW_DIR)
        raw_images = get_local_image_file_names(raw_dir)
        new_images = get_local_image_file_names(new_dir)

        for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
            try:
                if photo.filename not in raw_images and photo.filename not in new_images:
                    # use the dataset's own limits; __init__ already filled in the
                    # process-wide defaults where the dataset did not set them
                    download_image(photo, new_dir, min_width=config.min_width, min_height=config.min_height)
                    results['num_downloaded'] += 1
                else:
                    results['num_skipped'] += 1
            except Exception as e:
                # a single bad photo must not abort the whole collection
                print(f" - BAD({photo.id}): {e}")
                results['bad'] += 1
                continue
            results['total'] += 1

        return results

    def print_results(self, results):
        """Print the counters of one result dict on a single line."""
        print(
            f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")

    def run(self):
        """Sync all datasets, move the new images into place and print a summary."""
        super().run()
        print(f"Syncing {len(self.dataset_configs)} datasets")
        all_results = None
        failed_datasets = []
        for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
            try:
                results = self.sync_dataset(dataset_config)
                if all_results is None:
                    all_results = {**results}
                else:
                    for key, value in results.items():
                        all_results[key] += value

                self.print_results(results)
            except Exception as e:
                print(f" - FAILED: {e}")
                # HTTP errors may carry a response; it can be present but None
                response = getattr(e, 'response', None)
                if response is not None:
                    error = f"{response.status_code}: {response.text}"
                    print(f" - {error}")
                else:
                    error = str(e)
                failed_datasets.append({'dataset': dataset_config, 'error': error})
                continue

        print("Moving new images to raw")
        for dataset_config in self.dataset_configs:
            self.move_new_images(dataset_config.directory)

        print("Done syncing datasets")
        if all_results is not None:
            # stays None when there were no datasets or every dataset failed
            self.print_results(all_results)

        if len(failed_datasets) > 0:
            print(f"Failed to sync {len(failed_datasets)} datasets")
            for failed in failed_datasets:
                print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
                print(f" - ERR: {failed['error']}")
# steps that (re)generate caption text
caption_manipulation_steps = ['caption', 'caption_short']

default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
default_short_prompt = 'caption this image in less than ten words'

# (search, replacement) pairs applied to the lowercased caption
default_replacements = [
    ("the image features", ""),
    ("the image shows", ""),
    ("the image depicts", ""),
    ("the image is", ""),
    ("in this image", ""),
    ("in the image", ""),
]


def clean_caption(cap, replacements=None):
    """Normalize a generated caption into a lowercase, comma separated tag list.

    Newlines and periods become commas, double quotes and non-ascii characters
    are dropped, whitespace is collapsed and the text is lowercased. Then the
    replacements are applied and the result is split on commas, trimmed,
    de-duplicated (keeping first occurrence order) and joined with ", ".

    Args:
        cap: raw caption text.
        replacements: sequence of (search, replacement) pairs; defaults to
            ``default_replacements``. A search text starting with '*' clears
            the whole caption when the caption starts with the rest of the
            search text. Matching is case-insensitive for both forms.

    Returns:
        The cleaned caption string (possibly empty).
    """
    if replacements is None:
        replacements = default_replacements

    # remove any newlines
    cap = cap.replace("\n", ", ")
    cap = cap.replace("\r", ", ")
    cap = cap.replace(".", ",")
    cap = cap.replace("\"", "")

    # remove unicode characters
    cap = cap.encode('ascii', 'ignore').decode('ascii')

    # make lowercase
    cap = cap.lower()
    # remove any extra spaces
    cap = " ".join(cap.split())

    for replacement in replacements:
        # the caption is already lowercase, so the pattern must be too
        search_text = replacement[0].lower()
        if search_text.startswith('*'):
            # we are removing all text if it starts with this and the rest matches
            if cap.startswith(search_text[1:]):
                cap = ""
        else:
            cap = cap.replace(search_text, replacement[1].lower())

    # trim whitespace and remove empty strings
    cap_list = [c.strip() for c in cap.split(",")]
    cap_list = [c for c in cap_list if c != ""]
    # remove duplicates while keeping order
    cap_list = list(dict.fromkeys(cap_list))
    # join back together
    return ", ".join(cap_list)
class DatasetSyncCollectionConfig:
    """Settings for syncing one remote image collection into a local directory.

    ``host``, ``collection_id``, ``directory`` and ``api_key`` are required;
    ``min_width`` and ``min_height`` default to 1024.
    """

    def __init__(self, **kwargs):
        self.host: Host = kwargs.get('host')
        self.collection_id: str = kwargs.get('collection_id')
        self.directory: str = kwargs.get('directory')
        self.api_key: str = kwargs.get('api_key')
        self.min_width: int = kwargs.get('min_width', 1024)
        self.min_height: int = kwargs.get('min_height', 1024)

        # required settings are checked in a fixed order so the first missing
        # one is the one reported
        for required in ('host', 'collection_id', 'directory'):
            if getattr(self, required) is None:
                raise ValueError(f"{required} is required")
        if self.api_key is None:
            raise ValueError(f"api_key is required: {self.host}:{self.collection_id}")
class ImgInfo:
    """Per-image metadata: captions, points of interest and processing state.

    Instances are built from the keyword arguments of a previously saved dict
    (see ``to_dict``). ``is_dirty`` is set whenever something changed that
    should be written back to disk.
    """

    def __init__(self, **kwargs):
        self.version: int = kwargs.get('version', None)
        self.caption: str = kwargs.get('caption', None)
        self.caption_short: str = kwargs.get('caption_short', None)
        self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])]
        self.state = ImageState(**kwargs.get('state', {}))
        self.caption_method = kwargs.get('caption_method', None)
        self.other_captions = kwargs.get('other_captions', {})
        self.force_image_process: bool = False
        self._requested_steps: list[Step] = []

        # initialise the flag BEFORE upgrading, so changes made by the upgrade
        # are not lost by a later reset
        self.is_dirty: bool = False
        self._upgrade_state()

    def _upgrade_state(self):
        """Bring info loaded from an older file up to the current layout."""
        # upgrades older states
        if self.caption is not None and 'caption' not in self.state.steps_complete:
            self.mark_step_complete('caption')
            self.is_dirty = True
        if self.caption_short is not None and 'caption_short' not in self.state.steps_complete:
            self.mark_step_complete('caption_short')
            self.is_dirty = True
        if self.caption_method is None and self.caption is not None:
            # added caption method in version 2. Was all llava before that
            self.caption_method = 'llava:default'
            self.is_dirty = True

    def to_dict(self):
        """Return the serializable representation of this info."""
        return {
            'version': self.version,
            'caption_method': self.caption_method,
            'caption': self.caption,
            'caption_short': self.caption_short,
            'poi': [poi.to_dict() for poi in self.poi],
            'state': self.state.to_dict(),
            'other_captions': self.other_captions
        }

    def mark_step_complete(self, step: Step):
        """Record ``step`` as done and remove it from the pending list."""
        if step not in self.state.steps_complete:
            self.state.steps_complete.append(step)
        if step in self.state.steps_to_complete:
            self.state.steps_to_complete.remove(step)
        self.is_dirty = True

    def add_step(self, step: Step):
        """Queue ``step`` unless it is already pending or complete."""
        if step not in self.state.steps_to_complete and step not in self.state.steps_complete:
            self.state.steps_to_complete.append(step)

    def trigger_image_reprocess(self):
        """Force all requested image manipulation steps to run again."""
        # NOTE(review): _requested_steps starts as [] in __init__, so this guard
        # only fires if a caller sets it to None explicitly.
        if self._requested_steps is None:
            raise Exception("Must call add_steps before trigger_image_reprocess")
        steps = self._requested_steps
        # remove all image manipulation steps from both lists
        for step in img_manipulation_steps:
            if step in self.state.steps_to_complete:
                self.state.steps_to_complete.remove(step)
            if step in self.state.steps_complete:
                self.state.steps_complete.remove(step)
        self.force_image_process = True
        self.is_dirty = True
        # we want to keep the order passed in process file
        for step in steps:
            if step in img_manipulation_steps:
                self.add_step(step)

    def add_steps(self, steps: list[Step]):
        """Register the requested steps and decide whether the image must be redone.

        The image is reprocessed when any image manipulation step is still
        pending, or when the requested manipulation steps differ (in content or
        order) from the ones already completed.
        """
        self._requested_steps = [step for step in steps]
        for stage in steps:
            self.add_step(stage)

        # if any image manipulation step is pending we have to redo them all
        is_manipulating_image = any(step in img_manipulation_steps for step in self.state.steps_to_complete)
        order_has_changed = False

        if not is_manipulating_image:
            # check to see if order has changed. No need to if already redoing it.
            # This also detects removed steps.
            target_img_manipulation_order = [step for step in steps if step in img_manipulation_steps]
            current_img_manipulation_order = [step for step in self.state.steps_complete if
                                              step in img_manipulation_steps]
            if target_img_manipulation_order != current_img_manipulation_order:
                order_has_changed = True

        if is_manipulating_image or order_has_changed:
            self.trigger_image_reprocess()

    def set_caption_method(self, method: str):
        """Switch to ``method``, archiving the current captions under the old method.

        If captions for ``method`` were archived earlier they are restored,
        otherwise the caption steps are queued again.
        """
        if self._requested_steps is None:
            raise Exception("Must call add_steps before set_caption_method")
        if self.caption_method != method:
            self.is_dirty = True
            # move previous caption method to other_captions; only possible when
            # the previous method is known and there is something to keep
            has_captions = self.caption is not None or self.caption_short is not None
            if self.caption_method is not None and has_captions:
                self.other_captions[self.caption_method] = {
                    'caption': self.caption,
                    'caption_short': self.caption_short,
                }
            self.caption_method = method
            self.caption = None
            self.caption_short = None
            # see if we have a caption from the new method
            if method in self.other_captions:
                self.caption = self.other_captions[method].get('caption', None)
                self.caption_short = self.other_captions[method].get('caption_short', None)
            else:
                self.trigger_new_caption()

    def trigger_new_caption(self):
        """Clear both captions and queue every completed caption step again."""
        self.caption = None
        self.caption_short = None
        self.is_dirty = True
        # move completed caption steps back to the to_complete list; iterate over
        # a copy because the list is modified inside the loop
        for step in list(self.state.steps_complete):
            if step in caption_manipulation_steps:
                self.state.steps_complete.remove(step)
                if step not in self.state.steps_to_complete:
                    self.state.steps_to_complete.append(step)

    def to_json(self):
        """Return ``to_dict()`` encoded as a JSON string."""
        return json.dumps(self.to_dict())

    def set_version(self, version: int):
        """Set the format version, marking the info dirty when it changes."""
        if self.version != version:
            self.is_dirty = True
            self.version = version
class FuyuImageProcessor:
    """Caption images with the adept/fuyu-8b model (4-bit quantized).

    The model is not loaded on construction; call ``load_model`` first
    (``is_loaded`` tells whether that has happened).
    """

    def __init__(self, device='cuda'):
        from transformers import FuyuProcessor, FuyuForCausalLM
        self.device = device
        self.model: FuyuForCausalLM = None
        self.processor: FuyuProcessor = None
        self.dtype = torch.bfloat16
        # assigned in load_model; initialised so the attribute always exists
        self.tokenizer: AutoTokenizer = None
        self.is_loaded = False

    def load_model(self):
        """Load tokenizer, quantized model and processor, then mark the instance loaded."""
        from transformers import FuyuProcessor, FuyuForCausalLM
        # this module's class shadows the transformers class of the same name,
        # so import the transformers one under an alias
        from transformers import FuyuImageProcessor as HFFuyuImageProcessor

        model_path = "adept/fuyu-8b"
        # 4-bit loading is requested through quantization_config only; passing
        # load_in_4bit as well is rejected by from_pretrained
        kwargs = {
            "device_map": self.device,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=self.dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4'
            ),
        }

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # load the weights once (the model was previously loaded twice)
        self.model = FuyuForCausalLM.from_pretrained(model_path, torch_dtype=self.dtype, **kwargs)
        self.processor = FuyuProcessor(image_processor=HFFuyuImageProcessor(), tokenizer=self.tokenizer)

        # only report loaded once everything is in place
        self.is_loaded = True

    def generate_caption(
            self, image: Image,
            prompt: str = default_long_prompt,
            replacements=default_replacements,
            max_new_tokens=512
    ):
        """Generate a caption for ``image`` and clean it with ``clean_caption``.

        Args:
            image: image to caption.
            prompt: text prompt given to the model.
            replacements: replacement pairs forwarded to ``clean_caption``.
            max_new_tokens: generation length limit.

        Returns:
            The cleaned caption string.
        """
        # prepare inputs for the model
        model_inputs = self.processor(text=prompt, images=[image])
        # floating point inputs use the model dtype, everything else keeps its own
        model_inputs = {
            k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device)
            for k, v in model_inputs.items()
        }

        generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # drop the prompt tokens so only the newly generated text is decoded
        prompt_len = model_inputs["input_ids"].shape[-1]
        output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
        output = clean_caption(output, replacements=replacements)
        return output
exif data + image = ImageOps.exif_transpose(image) + except Exception as e: + pass + return image + + +def resize_to_max(image, max_width=1024, max_height=1024): + width, height = image.size + if width <= max_width and height <= max_height: + return image + + scale = min(max_width / width, max_height / height) + width = int(width * scale) + height = int(height * scale) + + return image.resize((width, height), Image.LANCZOS) diff --git a/ai-toolkit/extensions_built_in/dataset_tools/tools/llava_utils.py b/ai-toolkit/extensions_built_in/dataset_tools/tools/llava_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba38d6613c2ef5a8cd120d8906e5b8d236240bd --- /dev/null +++ b/ai-toolkit/extensions_built_in/dataset_tools/tools/llava_utils.py @@ -0,0 +1,85 @@ + +from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption + +import torch +from PIL import Image, ImageOps + +from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor + +img_ext = ['.jpg', '.jpeg', '.png', '.webp'] + + +class LLaVAImageProcessor: + def __init__(self, device='cuda'): + try: + from llava.model import LlavaLlamaForCausalLM + except ImportError: + # print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git") + print( + "You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git") + raise + self.device = device + self.model: LlavaLlamaForCausalLM = None + self.tokenizer: AutoTokenizer = None + self.image_processor: CLIPImageProcessor = None + self.is_loaded = False + + def load_model(self): + from llava.model import LlavaLlamaForCausalLM + + model_path = "4bit/llava-v1.5-13b-3GB" + # kwargs = {"device_map": "auto"} + kwargs = {"device_map": self.device} + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + 
bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + vision_tower = self.model.get_vision_tower() + if not vision_tower.is_loaded: + vision_tower.load_model() + vision_tower.to(device=self.device) + self.image_processor = vision_tower.image_processor + self.is_loaded = True + + def generate_caption( + self, image: + Image, prompt: str = default_long_prompt, + replacements=default_replacements, + max_new_tokens=512 + ): + from llava.conversation import conv_templates, SeparatorStyle + from llava.utils import disable_torch_init + from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria + # question = "how many dogs are in the picture?" + disable_torch_init() + conv_mode = "llava_v0" + conv = conv_templates[conv_mode].copy() + roles = conv.roles + image_tensor = self.image_processor.preprocess([image], return_tensors='pt')['pixel_values'].half().cuda() + + inp = f"{roles[0]}: {prompt}" + inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + raw_prompt = conv.get_prompt() + input_ids = tokenizer_image_token(raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, + return_tensors='pt').unsqueeze(0).cuda() + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) + with torch.inference_mode(): + output_ids = self.model.generate( + input_ids, images=image_tensor, do_sample=True, temperature=0.1, + max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria], + top_p=0.8 + ) + outputs = 
self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() + conv.messages[-1][-1] = outputs + output = outputs.rsplit('', 1)[0] + return clean_caption(output, replacements=replacements) diff --git a/ai-toolkit/extensions_built_in/dataset_tools/tools/sync_tools.py b/ai-toolkit/extensions_built_in/dataset_tools/tools/sync_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..143cc6bb93a9269eeafffe737f96302e2ca40787 --- /dev/null +++ b/ai-toolkit/extensions_built_in/dataset_tools/tools/sync_tools.py @@ -0,0 +1,279 @@ +import os +import requests +import tqdm +from typing import List, Optional, TYPE_CHECKING + + +def img_root_path(img_id: str): + return os.path.dirname(os.path.dirname(img_id)) + + +if TYPE_CHECKING: + from .dataset_tools_config_modules import DatasetSyncCollectionConfig + +img_exts = ['.jpg', '.jpeg', '.webp', '.png'] + +class Photo: + def __init__( + self, + id, + host, + width, + height, + url, + filename + ): + self.id = str(id) + self.host = host + self.width = width + self.height = height + self.url = url + self.filename = filename + + +def get_desired_size(img_width: int, img_height: int, min_width: int, min_height: int): + if img_width > img_height: + scale = min_height / img_height + else: + scale = min_width / img_width + + new_width = int(img_width * scale) + new_height = int(img_height * scale) + + return new_width, new_height + + +def get_pexels_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]: + all_images = [] + next_page = f"https://api.pexels.com/v1/collections/{config.collection_id}?page=1&per_page=80&type=photos" + + while True: + response = requests.get(next_page, headers={ + "Authorization": f"{config.api_key}" + }) + response.raise_for_status() + data = response.json() + all_images.extend(data['media']) + if 'next_page' in data and data['next_page']: + next_page = data['next_page'] + else: + break + + photos = [] + for image in all_images: + new_width, new_height = 
get_desired_size(image['width'], image['height'], config.min_width, config.min_height) + url = f"{image['src']['original']}?auto=compress&cs=tinysrgb&h={new_height}&w={new_width}" + filename = os.path.basename(image['src']['original']) + + photos.append(Photo( + id=image['id'], + host="pexels", + width=image['width'], + height=image['height'], + url=url, + filename=filename + )) + + return photos + + +def get_unsplash_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]: + headers = { + # "Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}" + "Authorization": f"Client-ID {config.api_key}" + } + # headers['Authorization'] = f"Bearer {token}" + + url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page=1&per_page=30" + response = requests.get(url, headers=headers) + response.raise_for_status() + res_headers = response.headers + # parse the link header to get the next page + # 'Link': '; rel="last", ; rel="next"' + has_next_page = False + if 'Link' in res_headers: + has_next_page = True + link_header = res_headers['Link'] + link_header = link_header.split(',') + link_header = [link.strip() for link in link_header] + link_header = [link.split(';') for link in link_header] + link_header = [[link[0].strip('<>'), link[1].strip().strip('"')] for link in link_header] + link_header = {link[1]: link[0] for link in link_header} + + # get page number from last url + last_page = link_header['rel="last'] + last_page = last_page.split('?')[1] + last_page = last_page.split('&') + last_page = [param.split('=') for param in last_page] + last_page = {param[0]: param[1] for param in last_page} + last_page = int(last_page['page']) + + all_images = response.json() + + if has_next_page: + # assume we start on page 1, so we don't need to get it again + for page in tqdm.tqdm(range(2, last_page + 1)): + url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page={page}&per_page=30" + response = requests.get(url, headers=headers) + 
response.raise_for_status() + all_images.extend(response.json()) + + photos = [] + for image in all_images: + new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height) + url = f"{image['urls']['raw']}&w={new_width}" + filename = f"{image['id']}.jpg" + + photos.append(Photo( + id=image['id'], + host="unsplash", + width=image['width'], + height=image['height'], + url=url, + filename=filename + )) + + return photos + + +def get_img_paths(dir_path: str): + os.makedirs(dir_path, exist_ok=True) + local_files = os.listdir(dir_path) + # remove non image files + local_files = [file for file in local_files if os.path.splitext(file)[1].lower() in img_exts] + # make full path + local_files = [os.path.join(dir_path, file) for file in local_files] + return local_files + + +def get_local_image_ids(dir_path: str): + os.makedirs(dir_path, exist_ok=True) + local_files = get_img_paths(dir_path) + # assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg' + return set([os.path.basename(file).split('.')[0] for file in local_files]) + + +def get_local_image_file_names(dir_path: str): + os.makedirs(dir_path, exist_ok=True) + local_files = get_img_paths(dir_path) + # assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg' + return set([os.path.basename(file) for file in local_files]) + + +def download_image(photo: Photo, dir_path: str, min_width: int = 1024, min_height: int = 1024): + img_width = photo.width + img_height = photo.height + + if img_width < min_width or img_height < min_height: + raise ValueError(f"Skipping {photo.id} because it is too small: {img_width}x{img_height}") + + img_response = requests.get(photo.url) + img_response.raise_for_status() + os.makedirs(dir_path, exist_ok=True) + + filename = os.path.join(dir_path, photo.filename) + with open(filename, 'wb') as file: + file.write(img_response.content) + + +def update_caption(img_path: str): + # if the caption is a txt file, convert it to a 
json file + filename_no_ext = os.path.splitext(os.path.basename(img_path))[0] + # see if it exists + if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json")): + # todo add poi and what not + return # we have a json file + caption = "" + # see if txt file exists + if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt")): + # read it + with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"), 'r') as file: + caption = file.read() + # write json file + with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json"), 'w') as file: + file.write(f'{{"caption": "{caption}"}}') + + # delete txt file + os.remove(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt")) + + +# def equalize_img(img_path: str): +# input_path = img_path +# output_path = os.path.join(img_root_path(img_path), COLOR_CORRECTED_DIR, os.path.basename(img_path)) +# os.makedirs(os.path.dirname(output_path), exist_ok=True) +# process_img( +# img_path=input_path, +# output_path=output_path, +# equalize=True, +# max_size=2056, +# white_balance=False, +# gamma_correction=False, +# strength=0.6, +# ) + + +# def annotate_depth(img_path: str): +# # make fake args +# args = argparse.Namespace() +# args.annotator = "midas" +# args.res = 1024 +# +# img = cv2.imread(img_path) +# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) +# +# output = annotate(img, args) +# +# output = output.astype('uint8') +# output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) +# +# os.makedirs(os.path.dirname(img_path), exist_ok=True) +# output_path = os.path.join(img_root_path(img_path), DEPTH_DIR, os.path.basename(img_path)) +# +# cv2.imwrite(output_path, output) + + +# def invert_depth(img_path: str): +# img = cv2.imread(img_path) +# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) +# # invert the colors +# img = cv2.bitwise_not(img) +# +# os.makedirs(os.path.dirname(img_path), exist_ok=True) +# output_path = os.path.join(img_root_path(img_path), 
INVERTED_DEPTH_DIR, os.path.basename(img_path)) +# cv2.imwrite(output_path, img) + + + # + # # update our list of raw images + # raw_images = get_img_paths(raw_dir) + # + # # update raw captions + # for image_id in tqdm.tqdm(raw_images, desc="Updating raw captions"): + # update_caption(image_id) + # + # # equalize images + # for img_path in tqdm.tqdm(raw_images, desc="Equalizing images"): + # if img_path not in eq_images: + # equalize_img(img_path) + # + # # update our list of eq images + # eq_images = get_img_paths(eq_dir) + # # update eq captions + # for image_id in tqdm.tqdm(eq_images, desc="Updating eq captions"): + # update_caption(image_id) + # + # # annotate depth + # depth_dir = os.path.join(root_dir, DEPTH_DIR) + # depth_images = get_img_paths(depth_dir) + # for img_path in tqdm.tqdm(eq_images, desc="Annotating depth"): + # if img_path not in depth_images: + # annotate_depth(img_path) + # + # depth_images = get_img_paths(depth_dir) + # + # # invert depth + # inv_depth_dir = os.path.join(root_dir, INVERTED_DEPTH_DIR) + # inv_depth_images = get_img_paths(inv_depth_dir) + # for img_path in tqdm.tqdm(depth_images, desc="Inverting depth"): + # if img_path not in inv_depth_images: + # invert_depth(img_path) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6dee4ad5174a2487a59f17f258d7d1c5b25e5e9 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/__init__.py @@ -0,0 +1,52 @@ +from .chroma import ChromaModel, ChromaRadianceModel +from .hidream import HidreamModel, HidreamE1Model +from .f_light import FLiteModel +from .omnigen2 import OmniGen2Model +from .flux_kontext import FluxKontextModel +from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel +from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel +from .flux2 import Flux2Model, Flux2Klein4BModel, 
Flux2Klein9BModel +from .z_image import ZImageModel +from .ltx2 import LTX2Model, LTX23Model +from .zeta_chroma import ZetaChromaModel +from .ernie_image import ErnieImageModel +from .nucleus_image import NucleusImageModel +from .hidream.hidream_o1_model import HidreamO1Model +from .z_image.z_image_l2p_model import ZImageL2PModel +from .ideogram4 import Ideogram4Model +from .prx_pixel_t2i import PRXPixelT2IModel +from .krea2 import Krea2Model +from .boogu_image import BooguImageModel, BooguImageEditModel + +AI_TOOLKIT_MODELS = [ + # put a list of models here + ChromaModel, + ChromaRadianceModel, + HidreamModel, + HidreamE1Model, + FLiteModel, + OmniGen2Model, + FluxKontextModel, + Wan225bModel, + Wan2214bI2VModel, + Wan2214bModel, + QwenImageModel, + QwenImageEditModel, + QwenImageEditPlusModel, + Flux2Model, + ZImageModel, + LTX2Model, + LTX23Model, + Flux2Klein4BModel, + Flux2Klein9BModel, + ZetaChromaModel, + ErnieImageModel, + NucleusImageModel, + HidreamO1Model, + ZImageL2PModel, + Ideogram4Model, + PRXPixelT2IModel, + Krea2Model, + BooguImageModel, + BooguImageEditModel, +] diff --git a/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e966e37298a0e5cec271e6e5375995706bb4d40e --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/__init__.py @@ -0,0 +1,4 @@ +from .boogu_image import BooguImageModel +from .boogu_image_edit import BooguImageEditModel + +__all__ = ["BooguImageModel", "BooguImageEditModel"] diff --git a/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/boogu_image.py b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/boogu_image.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5b75b6a29097cf0ce58f37cc11efe7520d8c12 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/boogu_image.py @@ 
-0,0 +1,459 @@ +"""Boogu-Image base (text-to-image) integration for ai-toolkit. + +Boogu-Image is a Lumina2-style mixed double-/single-stream flow-matching DiT +conditioned on Qwen3-VL instruction features. This wires up the base T2I model +for LoRA / fine-tune training and preview sampling. + +Only the base text-to-image path is implemented here (no reference-image / edit +conditioning). The architecture lives under ``./src`` (vendored & trimmed from the +upstream Boogu repo); nothing is imported from the original repo. + +Weights are pulled from the bf16 release ``Boogu/Boogu-Image-0.1-Base`` (clean +safetensors). The ``-fp8`` sibling ships torchao float8 ``.bin`` weights that +need a matching torchao/cache_dit to deserialize and is not supported here -- +use the bf16 repo and set ``quantize: true`` to run the transformer in fp8 via +ai-toolkit's own quantization. +""" + +import os +from typing import List, Optional + +import torch +import torch.nn.functional as F +import yaml +from safetensors.torch import save_file + +from transformers import AutoModel, AutoProcessor + +from toolkit.accelerator import unwrap_model +from toolkit.advanced_prompt_embeds import AdvancedPromptEmbeds +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.models.base_model import BaseModel +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.util.quantize import quantize, get_qtype, quantize_model + +from optimum.quanto import freeze, QTensor +from diffusers import AutoencoderKL + +from .src.transformer import BooguImageTransformer2DModel +from .src.rope import get_freqs_cis +from .src.pipeline import ( + BooguImagePipeline, + pad_instruction_features, + run_boogu_transformer, +) + + +# ai-toolkit uses CustomFlowMatchEulerDiscreteScheduler for training and (via our +# pipeline) sampling. 
``shift`` warps timesteps toward the high-noise end; 3.0 is a +# reasonable high-resolution default and Boogu's own time-shift is applied in the +# preview sampler (see src/pipeline.boogu_time_schedule). +scheduler_config = { + "num_train_timesteps": 1000, + "use_dynamic_shifting": False, + "shift": 3.0, +} + +# Released weights. The "-fp8" sibling ships torchao float8 weights that need +# cache_dit/torchao to deserialize; the plain repo ships clean bf16 safetensors, +# which load directly and let ai-toolkit do its own (optional) quantization. +BOOGU_BASE_PATH = "Boogu/Boogu-Image-0.1-Base" + +# System prompt the base T2I model was trained with (SYSTEM_PROMPT_4_T2I upstream). +SYSTEM_PROMPT_T2I = ( + "You are a helpful assistant that generates high-quality images based on user " + "instructions. The instructions are as follows." +) + +HF_TOKEN = os.getenv("HF_TOKEN", None) + + +def patch_qwen_vl_patch_embed(model) -> int: + """Swap Qwen-VL's vision ``patch_embed`` Conv3d for the equivalent ``F.linear``. + + Qwen-VL's patch_embed is a Conv3d whose kernel == stride, i.e. just a linear + projection of each flattened patch. bf16 Conv3d has no fast cuDNN kernel and + falls back to a slow path that effectively locks up image caching for the edit + (TI2I) model. The weight is read lazily so this survives later ``.to()`` moves. + Returns the number of patch_embed modules patched. (Vendored from + extensions_built_in/captioner/Qwen3VLCaptioner.py.) 
+ """ + patched = 0 + for module in model.modules(): + proj = getattr(module, "proj", None) + if isinstance(proj, torch.nn.Conv3d) and tuple(proj.kernel_size) == tuple( + proj.stride + ): + + def fast_forward(hidden_states, _proj=proj): + w = _proj.weight.reshape(_proj.weight.shape[0], -1) + x = hidden_states.view(-1, w.shape[1]).to(w.dtype) + return F.linear(x, w, _proj.bias) + + module.forward = fast_forward + patched += 1 + return patched + + +class BooguImageModel(BaseModel): + arch = "boogu_image" + # Default HF repo when model.name_or_path is unset (overridden by the edit model). + default_repo = BOOGU_BASE_PATH + use_old_lokr_format = False + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["BooguImageTransformer2DModel"] + + self.patch_size = 2 + self.vae_scale_factor = 8 + # Safety cap on instruction token length (truncation only). Each caption is + # encoded at its natural length and padded to the batch max at the model + # call, so this is just an upper bound. + self.max_text_length = int( + self.model_config.model_kwargs.get("max_text_length", 1024) + ) + + # Lazily-built, resolution-independent rotary frequency tables. + self._freqs_cis = None + + @property + def text_embedding_space_version(self): + return self.arch + "_v1" + + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + # 8 for the VAE downsample, 2 for the patch size. 
+ return self.vae_scale_factor * self.patch_size + + def get_freqs_cis(self): + """Precompute (once) the per-axis rotary frequency tables for the model.""" + if self._freqs_cis is None: + cfg = unwrap_model(self.model).config + self._freqs_cis = get_freqs_cis( + cfg.axes_dim_rope, cfg.axes_lens, theta=10000 + ) + return self._freqs_cis + + # ------------------------------------------------------------------ + # Loading + # ------------------------------------------------------------------ + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading Boogu-Image model") + base = self.model_config.name_or_path or self.default_repo + + # --- transformer --- + # Loads the bf16 release (clean safetensors). The "-fp8" sibling ships + # torchao float8 .bin weights that need a matching torchao/cache_dit to + # deserialize -- use the bf16 repo and let ai-toolkit quantize if wanted. + self.print_and_status_update("Loading transformer") + try: + transformer = BooguImageTransformer2DModel.from_pretrained( + base, subfolder="transformer", torch_dtype=dtype, token=HF_TOKEN + ) + except OSError as e: + raise OSError( + f"Could not load Boogu transformer safetensors from '{base}'. The " + f"'-fp8' release ships torchao float8 .bin weights, which are not " + f"supported here -- point model.name_or_path at the bf16 repo " + f"'{BOOGU_BASE_PATH}' instead." + ) from e + transformer.eval() + flush() + + # Attention defaults to torch SDPA ("native"); opt into Flash Attention 2 + # with model_kwargs.attention_backend: "flash" (needs the flash_attn pkg). 
+ attention_backend = self.model_config.model_kwargs.get( + "attention_backend", "native" + ) + if attention_backend != "native": + transformer.set_attention_backend(attention_backend) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing transformer") + quantize_model(self, transformer) + flush() + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + else: + transformer.to(self.device_torch, dtype=dtype) + flush() + + # --- instruction encoder (Qwen3-VL) + processor --- + te_path = self.model_config.model_kwargs.get("text_encoder_path", base) + te_subfolder = self.model_config.model_kwargs.get( + "text_encoder_subfolder", "mllm" + ) + self.print_and_status_update("Loading Qwen3-VL instruction encoder") + processor = AutoProcessor.from_pretrained( + te_path, subfolder="processor", token=HF_TOKEN + ) + # AutoModel yields the inner Qwen3VLModel (the ``.model`` of the + # *ForConditionalGeneration), whose last_hidden_state is exactly the + # instruction feature the Boogu pipeline consumes. + text_encoder = AutoModel.from_pretrained( + te_path, subfolder=te_subfolder, torch_dtype=dtype, token=HF_TOKEN + ) + text_encoder.eval() + text_encoder.requires_grad_(False) + # The vision tower's bf16 Conv3d patch_embed has no fast kernel and stalls + # image caching for the edit model -- swap it for an equivalent F.linear. + # No-op for the base T2I model (it never runs the vision tower). 
+ n_patched = patch_qwen_vl_patch_embed(text_encoder) + if n_patched: + self.print_and_status_update( + f" - patched {n_patched} Qwen-VL Conv3d patch_embed -> linear" + ) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing instruction encoder") + text_encoder.to(self.device_torch) + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + + if self.model_config.low_vram: + self.print_and_status_update("Moving instruction encoder to CPU") + text_encoder.to("cpu") + else: + text_encoder.to(self.device_torch) + flush() + + # --- VAE (FLUX AutoencoderKL) --- + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + base, subfolder="vae", torch_dtype=self.vae_torch_dtype, token=HF_TOKEN + ) + vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + vae.eval() + vae.requires_grad_(False) + flush() + + self.noise_scheduler = BooguImageModel.get_train_scheduler() + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = processor + self.model = transformer + self.pipeline = BooguImagePipeline(self) + self.print_and_status_update("Model Loaded") + + # ------------------------------------------------------------------ + # Generation + # ------------------------------------------------------------------ + def get_generation_pipeline(self): + return BooguImagePipeline(self) + + def generate_single_image( + self, + pipeline: BooguImagePipeline, + gen_config: GenerateImageConfig, + conditional_embeds: AdvancedPromptEmbeds, + unconditional_embeds: AdvancedPromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + img = pipeline( + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, 
+ height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + )[0] + return img + + # ------------------------------------------------------------------ + # Training hooks + # ------------------------------------------------------------------ + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, # (B, 16, h, w) + timestep: torch.Tensor, # 0..1000 scale (1000 = pure noise) + text_embeddings: AdvancedPromptEmbeds, + **kwargs, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + # toolkit timestep (0..1000, 1000=noise) -> Boogu native time (0=noise, 1=clean) + t01 = timestep.to(self.device_torch, dtype=torch.float32) / 1000.0 + if t01.dim() == 0: + t01 = t01.unsqueeze(0) + if t01.shape[0] != latent_model_input.shape[0]: + t01 = t01.expand(latent_model_input.shape[0]) + boogu_t = 1.0 - t01 + + instr_feats, instr_mask = pad_instruction_features( + text_embeddings.text_embeds, self.device_torch, self.torch_dtype + ) + + # Model predicts clean - noise; negate to return the toolkit velocity + # (noise - clean), matching get_loss_target / the scheduler. + raw_velocity = run_boogu_transformer( + self.transformer, + latent_model_input.to(self.device_torch, self.torch_dtype), + boogu_t, + instr_feats, + instr_mask, + self.get_freqs_cis(), + ) + return -raw_velocity + + def get_prompt_embeds(self, prompt) -> AdvancedPromptEmbeds: + if isinstance(prompt, str): + prompt = [prompt] + + if self.text_encoder.device == torch.device("cpu"): + self.text_encoder.to(self.device_torch) + device = self.text_encoder.device + + # Encode each instruction at its natural length (no cross-sample padding); + # padding to a common length is deferred to the model call. The system + # prompt + chat template match the base T2I training setup. 
+ features_list = [] + for p in prompt: + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": SYSTEM_PROMPT_T2I}], + }, + {"role": "user", "content": [{"type": "text", "text": p}]}, + ] + inputs = self.tokenizer.apply_chat_template( + [messages], + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=False, + truncation=True, + max_length=self.max_text_length, + ) + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + with torch.no_grad(): + output = self.text_encoder( + input_ids=input_ids, attention_mask=attention_mask + ) + # (L, D) -- drop the batch dim, one tensor per prompt + features_list.append(output.last_hidden_state[0].to(self.torch_dtype)) + + return AdvancedPromptEmbeds(text_embeds=features_list) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get("noise") + batch = kwargs.get("batch") + return (noise - batch.latents).detach() + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + # ------------------------------------------------------------------ + # VAE + # ------------------------------------------------------------------ + def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + if self.vae.device == torch.device("cpu"): + self.vae.to(self.vae_device_torch) + + if isinstance(image_list, list): + images = torch.stack(image_list, dim=0) + else: + images = image_list + images = images.to(device, dtype=dtype) + + latents = self.vae.encode(images).latent_dist.sample() + shift = self.vae.config["shift_factor"] or 0 + latents = (latents - shift) * self.vae.config["scaling_factor"] + return latents.to(device, dtype=dtype) + + def decode_latents(self, latents: torch.Tensor, device=None, dtype=None): + if device is None: + device = self.vae_device_torch + if dtype is 
None: + dtype = self.vae_torch_dtype + if self.vae.device == torch.device("cpu"): + self.vae.to(self.vae_device_torch) + + latents = latents.to(device, dtype=dtype) + shift = self.vae.config["shift_factor"] or 0 + latents = latents / self.vae.config["scaling_factor"] + shift + return self.vae.decode(latents).sample + + # ------------------------------------------------------------------ + # Saving / misc + # ------------------------------------------------------------------ + def save_model(self, output_path, meta, save_dtype): + transformer: BooguImageTransformer2DModel = unwrap_model(self.model) + transformer_dir = os.path.join(output_path, "transformer") + os.makedirs(transformer_dir, exist_ok=True) + + state_dict = transformer.state_dict() + save_dict = {} + for k, v in state_dict.items(): + if isinstance(v, QTensor): + v = v.dequantize() + save_dict[k] = v.clone().to("cpu", dtype=save_dtype) + save_file( + save_dict, + os.path.join(transformer_dir, "diffusion_pytorch_model.safetensors"), + ) + # config.json so the saved transformer can be reloaded with from_pretrained. 
+ transformer.save_config(transformer_dir) + with open(os.path.join(output_path, "aitk_meta.yaml"), "w") as f: + yaml.dump(meta, f) + + def get_base_model_version(self): + return "boogu_image.0.1" + + def get_transformer_block_names(self) -> Optional[List[str]]: + return ["double_stream_layers", "single_stream_layers"] + + def convert_lora_weights_before_save(self, state_dict): + return { + k.replace("transformer.", "diffusion_model."): v + for k, v in state_dict.items() + } + + def convert_lora_weights_before_load(self, state_dict): + return { + k.replace("diffusion_model.", "transformer."): v + for k, v in state_dict.items() + } diff --git a/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/boogu_image_edit.py b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/boogu_image_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..cd910358533e6f4f76ec7680b155f4965c39dbd5 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/boogu_image_edit.py @@ -0,0 +1,386 @@ +"""Boogu-Image edit (TI2I) integration for ai-toolkit. + +The edit model is the same Lumina2-style transformer + Qwen3-VL encoder as the +base T2I model, with reference-image conditioning. A reference image feeds the +model in TWO places: + + 1. Into the Qwen3-VL instruction encoder as image content alongside the edit + instruction (so the *text embeddings* already encode the reference image). + This is why ``encode_control_in_text_embeddings = True``. + 2. Into the transformer as reference-image VAE latents + (``ref_image_hidden_states``), which the ref-image refiner + double-stream + blocks attend to. + +Everything else (transformer, VAE, scheduler, time/velocity convention, saving) +is inherited from ``BooguImageModel`` -- this file only overrides the pieces +that change for TI2I. 
+""" + +import math +from typing import TYPE_CHECKING, List, Optional + +import torch +import torch.nn.functional as F +from PIL import Image +from torchvision.transforms.functional import to_tensor + +from toolkit.advanced_prompt_embeds import AdvancedPromptEmbeds +from toolkit.config_modules import GenerateImageConfig, ModelConfig + +from .boogu_image import BooguImageModel +from .src.pipeline import ( + BooguImagePipeline, + pad_instruction_features, + run_boogu_transformer, +) + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + + +# Edit release (clean bf16 safetensors); same layout as the base repo. +BOOGU_EDIT_PATH = "Boogu/Boogu-Image-0.1-Edit" + +# System prompt the edit model was trained with (SYSTEM_PROMPT_4_TI2I upstream). +SYSTEM_PROMPT_TI2I = ( + "Describe the key features of the input image (color, shape, size, texture, " + "objects, background), then explain how the user's text instruction should " + "alter or modify the image. Generate a new image that meets the user's " + "requirements while maintaining consistency with the original input where " + "appropriate." +) + + +class BooguImageEditModel(BooguImageModel): + arch = "boogu_image_edit" + default_repo = BOOGU_EDIT_PATH + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + # The reference image is encoded into the Qwen3-VL instruction features, + # so get_prompt_embeds receives the control image(s). + self.encode_control_in_text_embeddings = True + # Boogu supports up to 5 reference images -> they arrive as a list. + self.has_multiple_control_images = True + # Reference images keep their own aspect/size (not resized to the target). 
+ self.use_raw_control_images = True + + @property + def text_embedding_space_version(self): + # Distinct from the base T2I cache: the edit features fold in the ref image. + return self.arch + "_v1" + + # ------------------------------------------------------------------ + # Reference-image helpers + # ------------------------------------------------------------------ + def _vlm_resize_hw(self, h, w, max_pixels, max_side, factor=16): + """Boogu's VLM image downscale (BooguImageProcessor.get_new_height_width). + + Scale down (never up) to fit BOTH ``max_pixels`` (area) and + ``max_side_length``, then round each dim down to a multiple of ``factor`` + (the image processor's ``vae_scale_factor`` = 16 for this model). The Qwen + processor's own smart_resize runs afterwards, exactly as upstream. + """ + longest = h if h > w else w + ratio_side = max_side / longest + ratio_pixels = (max_pixels / (h * w)) ** 0.5 + ratio = min(ratio_pixels, ratio_side, 1.0) + nh = max(factor, int(h * ratio) // factor * factor) + nw = max(factor, int(w * ratio) // factor * factor) + return nh, nw + + def _ref_target_pixels(self, target_pixels: Optional[int]) -> int: + """Decide the pixel budget each reference image is resized to fit within. + + - default: ``control_image_max_pixels`` model_kwarg (1 MP) -- a hard cap so + raw, full-size control images don't blow up the token count / VRAM. + - ``match_target_res`` model_kwarg: use the target generation area instead, + matching Boogu's recommendation of ``max_input_image_pixels ~= H*W``. + """ + max_pixels = int( + self.model_config.model_kwargs.get("control_image_max_pixels", 1024 * 1024) + ) + if ( + self.model_config.model_kwargs.get("match_target_res", False) + and target_pixels + ): + return int(target_pixels) + return max_pixels + + def _encode_ref_latents( + self, control_tensors, target_pixels: Optional[int] = None + ) -> List[torch.Tensor]: + """Encode ``[0, 1]`` reference image tensors to VAE latents. 
+ + Returns a list of ``(16, h, w)`` latents (one per reference image). Each + control image is resized so its area fits within the pixel budget (see + ``_ref_target_pixels``) -- preserving aspect ratio -- then snapped so the + latent grid is divisible by the patch size. ``control_tensors`` is a list + of ``(C, H, W)`` or ``(1, C, H, W)`` tensors in ``[0, 1]``. + """ + sc = self.get_bucket_divisibility() # 16: VAE(8) * patch(2) + budget = self._ref_target_pixels(target_pixels) + match = self.model_config.model_kwargs.get("match_target_res", False) + + latents = [] + for img in control_tensors: + if img.dim() == 3: + img = img.unsqueeze(0) + img = img.to(self.device_torch, dtype=self.torch_dtype) + + h, w = img.shape[2], img.shape[3] + # match_target_res: scale area *to* the budget; otherwise only scale + # *down* when the image is larger than the budget. + area = h * w + if match or area > budget: + ratio = h / w + new_h = math.sqrt(budget * ratio) + new_w = new_h / ratio + else: + new_h, new_w = float(h), float(w) + + # snap to a multiple of the bucket divisibility so the VAE latent grid + # is patchifiable (the transformer rearranges 2x2 latent patches). + new_h = max(sc, int(round(new_h / sc)) * sc) + new_w = max(sc, int(round(new_w / sc)) * sc) + if (new_h, new_w) != (h, w): + img = F.interpolate(img, size=(new_h, new_w), mode="bilinear") + + # encode_images expects [-1, 1]; control tensors arrive in [0, 1]. 
def get_prompt_embeds(self, prompt, control_images=None) -> AdvancedPromptEmbeds:
    """Encode edit instruction(s) together with their reference image(s).

    Args:
        prompt: one instruction string or a list of them.
        control_images: reference images as tensors. Accepted forms are a single
            tensor, a flat list of tensors (the references of ONE prompt) or a
            list of per-prompt lists. Tensors are indexed as ``(C, H, W)`` or
            ``(1, C, H, W)``. The processor is called with ``do_rescale=False``,
            so values are expected to already be in ``[0, 1]``.

    Returns:
        ``AdvancedPromptEmbeds`` whose ``text_embeds`` holds one
        ``last_hidden_state[0]`` tensor (cast to ``self.torch_dtype``) per prompt.

    Raises:
        ValueError: if ``control_images`` is ``None`` or an empty list, or if the
            number of prompts differs from the number of control-image sets.
    """
    if isinstance(prompt, str):
        prompt = [prompt]

    if control_images is None:
        raise ValueError("BooguImageEditModel requires control (reference) images")

    # Normalize to List[List[Tensor]] (per-prompt list of reference images).
    if not isinstance(control_images, list):
        control_images = [control_images]
    if len(control_images) == 0:
        # Previously this fell through to control_images[0] -> IndexError.
        raise ValueError(
            "BooguImageEditModel requires at least one control (reference) image"
        )
    if not isinstance(control_images[0], list):
        control_images = [control_images]
    if len(prompt) != len(control_images):
        raise ValueError(
            "Number of prompts must match number of control image sets"
        )

    if self.text_encoder.device == torch.device("cpu"):
        self.text_encoder.to(self.device_torch)
    device = self.text_encoder.device

    # VLM-side resize limits are loop-invariant: read them once. The encoder only
    # needs a coarse view of the reference image; capping the size also bounds
    # the number of image tokens in the instruction sequence.
    model_kwargs = self.model_config.model_kwargs
    max_pixels = int(model_kwargs.get("vlm_max_pixels", 384 * 384))
    max_side = int(model_kwargs.get("vlm_max_side_length", 768))

    features_list = []
    for p, ctrl in zip(prompt, control_images):
        # Keep reference images as tensors the whole way (no PIL round-trip).
        images = []
        for img in ctrl:
            if img.dim() == 4:
                img = img[0]
            img = img.to(device)
            nh, nw = self._vlm_resize_hw(
                img.shape[1], img.shape[2], max_pixels, max_side
            )
            if (nh, nw) != (img.shape[1], img.shape[2]):
                img = (
                    F.interpolate(
                        img.unsqueeze(0),
                        size=(nh, nw),
                        mode="bicubic",
                        antialias=True,
                    )
                    .squeeze(0)
                    .clamp(0, 1)  # bicubic can overshoot the [0, 1] range
                )
            images.append(img)

        # Build just the text template with image placeholders (tokenize=False),
        # then let the processor expand the image tokens from the real grid size.
        user_content = [{"type": "image"} for _ in images]
        user_content.append({"type": "text", "text": p})
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": SYSTEM_PROMPT_TI2I}],
            },
            {"role": "user", "content": user_content},
        ]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        # do_rescale=False: the tensors are already in [0, 1]. No size override:
        # the images were resized above.
        inputs = self.tokenizer(
            text=[text],
            images=images,
            return_tensors="pt",
            do_rescale=False,
        )
        model_inputs = {}
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                v = v.to(device)
                # cast image pixels to the encoder dtype; leave ids/masks as ints
                if v.is_floating_point():
                    v = v.to(self.torch_dtype)
            model_inputs[k] = v

        with torch.no_grad():
            output = self.text_encoder(**model_inputs)
        features_list.append(output.last_hidden_state[0].to(self.torch_dtype))

    return AdvancedPromptEmbeds(text_embeds=features_list)
def generate_single_image(
    self,
    pipeline: BooguImagePipeline,
    gen_config: GenerateImageConfig,
    conditional_embeds: AdvancedPromptEmbeds,
    unconditional_embeds: AdvancedPromptEmbeds,
    generator: torch.Generator,
    extra: dict,
):
    """Render one preview image, conditioning on the configured reference images.

    ``gen_config.width`` / ``gen_config.height`` are snapped down (in place) to a
    multiple of the bucket divisibility, but never below one bucket. Reference
    images are read from ``gen_config.ctrl_img`` and ``ctrl_img_1..3``; ``None``
    entries are skipped. If none is set, the pipeline is called with
    ``ref_latents=None``.

    Returns:
        The first element of the pipeline output.
    """
    if self.model.device == torch.device("cpu"):
        self.model.to(self.device_torch)

    sc = self.get_bucket_divisibility()
    # Clamp to one bucket so a size smaller than ``sc`` does not floor to 0
    # (same guard ``_encode_ref_latents`` applies to reference images).
    gen_config.width = max(sc, int(gen_config.width // sc * sc))
    gen_config.height = max(sc, int(gen_config.height // sc * sc))

    # Load the reference image(s) for the transformer ref latents. The MLLM
    # side already saw them (baked into conditional/unconditional embeds).
    ctrl_paths = [
        p
        for p in (
            gen_config.ctrl_img,
            gen_config.ctrl_img_1,
            gen_config.ctrl_img_2,
            gen_config.ctrl_img_3,
        )
        if p is not None
    ]
    ref_latents = None
    if ctrl_paths:
        ctrl_tensors = []
        for path in ctrl_paths:
            # Context manager closes the underlying file deterministically;
            # convert() returns an independent, fully loaded copy.
            with Image.open(path) as ref_image:
                ctrl_tensors.append(to_tensor(ref_image.convert("RGB")))
        target_pixels = gen_config.width * gen_config.height
        # one batch item (preview batch size is 1) -> List[List[(16, h, w)]]
        ref_latents = [
            self._encode_ref_latents(ctrl_tensors, target_pixels=target_pixels)
        ]

    img = pipeline(
        conditional_embeds=conditional_embeds,
        unconditional_embeds=unconditional_embeds,
        height=gen_config.height,
        width=gen_config.width,
        num_inference_steps=gen_config.num_inference_steps,
        guidance_scale=gen_config.guidance_scale,
        latents=gen_config.latents,
        generator=generator,
        ref_latents=ref_latents,
    )[0]
    return img
Flash +# Attention 2 is an OPTIONAL backend: each processor carries an +# ``attention_backend`` flag (set in bulk via +# ``BooguImageTransformer2DModel.set_attention_backend``) and only the "flash" +# branch touches the ``flash_attn`` package, so importing it stays lazy/guarded. +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention +from einops import repeat + +from .embeddings import apply_rotary_emb + +try: + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + _FLASH_ATTN_AVAILABLE = True +except ImportError: # flash-attn is optional; "native" SDPA needs none of this. + flash_attn_varlen_func = None + index_first_axis = pad_input = unpad_input = None + _FLASH_ATTN_AVAILABLE = False + +# Supported attention backends. "native" -> SDPA, "flash" -> Flash Attention 2. +ATTENTION_BACKENDS = ("native", "flash") + + +def _get_unpad_data(mask_2d: torch.Tensor): + """Indices / cu_seqlens / max_seqlen from a 2D padding mask [B, L].""" + seqlens_in_batch = mask_2d.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(mask_2d.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return indices, cu_seqlens, max_seqlen_in_batch + + +def _upad_input(query, key, value, attention_mask, query_length, num_heads): + """Unpad q/k/v for ``flash_attn_varlen_func`` given a [B, L] padding mask.""" + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key.shape + + key = index_first_axis( + key.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value = index_first_axis( + value.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 
def _flash_varlen_attention(query, key, value, attention_mask, attn, softmax_scale):
    """Run flash-attn varlen attention over padded ``[B, L, heads, head_dim]`` q/k/v.

    Args:
        query, key, value: padded tensors indexed as ``[B, L, heads, head_dim]``
            (key/value may carry fewer heads than ``attn.heads``).
        attention_mask: 2D padding mask ``[B, L]`` (non-zero = valid token), or
            ``None`` to treat every position as valid.
        attn: attention module; only ``attn.heads`` is read here.
        softmax_scale: scale passed through to ``flash_attn_varlen_func``.

    Returns:
        Attention output re-padded and flattened to ``[B, L, heads * head_dim]``.

    Raises:
        ImportError: if the optional ``flash_attn`` package is not installed.
    """
    if flash_attn_varlen_func is None:
        # The module-level import is guarded; fail with an actionable message
        # instead of "'NoneType' object is not callable".
        raise ImportError(
            'attention_backend="flash" requires the flash_attn package; '
            'install flash-attn or use the "native" backend.'
        )

    batch_size, sequence_length = query.shape[0], query.shape[1]
    kv_heads = key.shape[2]

    if attention_mask is None:
        # No padding information: every key position is valid. Passing None on
        # would crash while computing cu_seqlens.
        mask_2d = torch.ones(
            key.shape[0], key.shape[1], dtype=torch.bool, device=key.device
        )
    else:
        mask_2d = attention_mask.bool()

    (
        query_states,
        key_states,
        value_states,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_q, max_seqlen_k),
    ) = _upad_input(query, key, value, mask_2d, sequence_length, attn.heads)

    if kv_heads < attn.heads:
        # Grouped-query attention: expand k/v heads to match the query heads.
        key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
        value_states = repeat(
            value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads
        )

    attn_output_unpad = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=0.0,
        causal=False,
        softmax_scale=softmax_scale,
    )
    hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
    return hidden_states.flatten(-2)
BooguImageDoubleStreamSelfAttnProcessor(nn.Module): + """ + Double-stream self-attention processor. + + Instruction and image features each get their own q/k/v projections; the two + streams are concatenated (instruction first), attended jointly, then split + back and projected with separate output heads. Uses torch SDPA by default; + set ``attention_backend = "flash"`` for Flash Attention 2. + """ + + def __init__( + self, + head_dim: int, + num_attention_heads: int, + num_kv_heads: int, + qkv_bias: bool = False, + ) -> None: + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "BooguImageDoubleStreamSelfAttnProcessor requires PyTorch 2.0+." + ) + + self.head_dim = head_dim + self.num_attention_heads = num_attention_heads + self.num_kv_heads = num_kv_heads + self.attention_backend = "native" + + query_dim = head_dim * num_attention_heads + kv_dim = head_dim * num_kv_heads + + self.img_to_q = nn.Linear(query_dim, query_dim, bias=qkv_bias) + self.img_to_k = nn.Linear(query_dim, kv_dim, bias=qkv_bias) + self.img_to_v = nn.Linear(query_dim, kv_dim, bias=qkv_bias) + + self.instruct_to_q = nn.Linear(query_dim, query_dim, bias=qkv_bias) + self.instruct_to_k = nn.Linear(query_dim, kv_dim, bias=qkv_bias) + self.instruct_to_v = nn.Linear(query_dim, kv_dim, bias=qkv_bias) + + self.instruct_out = nn.Linear(query_dim, query_dim, bias=qkv_bias) + self.img_out = nn.Linear(query_dim, query_dim, bias=qkv_bias) + + self.initialize_weights() + + def initialize_weights(self) -> None: + nn.init.xavier_uniform_(self.img_to_q.weight) + nn.init.xavier_uniform_(self.img_to_k.weight) + nn.init.xavier_uniform_(self.img_to_v.weight) + nn.init.xavier_uniform_(self.instruct_to_q.weight) + nn.init.xavier_uniform_(self.instruct_to_k.weight) + nn.init.xavier_uniform_(self.instruct_to_v.weight) + nn.init.xavier_uniform_(self.instruct_out.weight) + nn.init.xavier_uniform_(self.img_out.weight) + + if self.img_to_q.bias is not None: + 
nn.init.zeros_(self.img_to_q.bias) + nn.init.zeros_(self.img_to_k.bias) + nn.init.zeros_(self.img_to_v.bias) + nn.init.zeros_(self.instruct_to_q.bias) + nn.init.zeros_(self.instruct_to_k.bias) + nn.init.zeros_(self.instruct_to_v.bias) + nn.init.zeros_(self.instruct_out.bias) + nn.init.zeros_(self.img_out.bias) + + def _concat_instruction_image_features( + self, + img_hidden_states_list: List[torch.Tensor], + instruct_hidden_states_list: List[torch.Tensor], + encoder_seq_lengths: List[int], + seq_lengths: List[int], + ) -> List[torch.Tensor]: + """Concatenate instruction then image features into one joint sequence.""" + batch_size = img_hidden_states_list[0].shape[0] + max_seq_len = max(seq_lengths) + + concatenated_list = [] + for img_tensor, instruct_tensor in zip( + img_hidden_states_list, instruct_hidden_states_list + ): + device = img_tensor.device + if instruct_tensor.device != device: + instruct_tensor = instruct_tensor.to(device) + + feature_dim = img_tensor.shape[-1] + concatenated = img_tensor.new_zeros(batch_size, max_seq_len, feature_dim) + + for i, (encoder_seq_len, seq_len) in enumerate( + zip(encoder_seq_lengths, seq_lengths) + ): + concatenated[i, :encoder_seq_len] = instruct_tensor[i, :encoder_seq_len] + concatenated[i, encoder_seq_len:seq_len] = img_tensor[ + i, : seq_len - encoder_seq_len + ] + + concatenated_list.append(concatenated) + + return concatenated_list + + def _split_instruction_image_features( + self, + hidden_states_list: List[torch.Tensor], + encoder_seq_lengths: List[int], + seq_lengths: List[int], + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """Inverse of ``_concat_instruction_image_features``.""" + result_list = [] + for hidden_states in hidden_states_list: + batch_size = hidden_states.shape[0] + feature_dim = hidden_states.shape[-1] + + max_instruct_len = max(encoder_seq_lengths) + max_img_len = max( + seq_len - encoder_seq_len + for seq_len, encoder_seq_len in zip(seq_lengths, encoder_seq_lengths) + ) + + 
instruct_hidden_states = hidden_states.new_zeros( + batch_size, max_instruct_len, feature_dim + ) + img_hidden_states = hidden_states.new_zeros( + batch_size, max_img_len, feature_dim + ) + + for i, (encoder_seq_len, seq_len) in enumerate( + zip(encoder_seq_lengths, seq_lengths) + ): + img_len = seq_len - encoder_seq_len + instruct_hidden_states[i, :encoder_seq_len] = hidden_states[ + i, :encoder_seq_len + ] + img_hidden_states[i, :img_len] = hidden_states[ + i, encoder_seq_len:seq_len + ] + + result_list.append((instruct_hidden_states, img_hidden_states)) + + return result_list + + def __call__( + self, + attn: Attention, + img_hidden_states: torch.Tensor, + instruct_hidden_states: torch.Tensor, + joint_attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + encoder_seq_lengths: List[int] = None, + seq_lengths: List[int] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + batch_size = img_hidden_states.shape[0] + + img_query = self.img_to_q(img_hidden_states) + img_key = self.img_to_k(img_hidden_states) + img_value = self.img_to_v(img_hidden_states) + + instruct_query = self.instruct_to_q(instruct_hidden_states) + instruct_key = self.instruct_to_k(instruct_hidden_states) + instruct_value = self.instruct_to_v(instruct_hidden_states) + + img_list = [img_query, img_key, img_value] + instruct_list = [instruct_query, instruct_key, instruct_value] + concatenated_list = self._concat_instruction_image_features( + img_list, instruct_list, encoder_seq_lengths, seq_lengths + ) + query, key, value = concatenated_list + + sequence_length = max(seq_lengths) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + kv_heads = inner_dim // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + if attn.norm_q is not None: + 
query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if rotary_emb is not None: + query = apply_rotary_emb(query, rotary_emb, use_real=False) + key = apply_rotary_emb(key, rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + if base_sequence_length is not None: + softmax_scale = ( + math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + ) + else: + softmax_scale = attn.scale + + if self.attention_backend == "flash": + # q/k/v are [B, L, heads, head_dim]; the joint padding mask is 2D. + hidden_states = _flash_varlen_attention( + query, key, value, joint_attention_mask, attn, softmax_scale + ) + hidden_states = hidden_states.type_as(query) + else: + if joint_attention_mask is not None: + joint_attention_mask = joint_attention_mask.bool() + if joint_attention_mask.dim() == 2: + joint_attention_mask = joint_attention_mask.view( + batch_size, 1, 1, -1 + ) + elif joint_attention_mask.dim() == 3: + joint_attention_mask = joint_attention_mask.unsqueeze(1) + else: + raise ValueError( + f"Unsupported joint_attention_mask shape: {joint_attention_mask.shape}" + ) + + q = query.transpose(1, 2) + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # explicitly repeat key/value to avoid the slow MATH SDPA backend that + # enable_gqa triggers on some torch builds + k = k.repeat_interleave(q.size(-3) // k.size(-3), -3) + v = v.repeat_interleave(q.size(-3) // v.size(-3), -3) + + hidden_states = F.scaled_dot_product_attention( + q, k, v, attn_mask=joint_attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.type_as(query) + + split_results = self._split_instruction_image_features( + [hidden_states], encoder_seq_lengths, seq_lengths + ) + instruct_hidden_states, img_hidden_states = split_results[0] + + instruct_projected = self.instruct_out(instruct_hidden_states) + img_projected = 
class BooguImageAttnProcessor:
    """
    Single-stream self-attention processor with RoPE + QK norm.

    Attention runs through torch SDPA unless ``attention_backend`` is set to
    ``"flash"`` AND the mask is absent or 2D, in which case the flash-attn
    varlen path is used (requires the ``flash_attn`` package).
    """

    def __init__(self) -> None:
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("BooguImageAttnProcessor requires PyTorch 2.0+.")
        self.attention_backend = "native"

    @staticmethod
    def _expand_mask_for_sdpa(attention_mask, batch_size):
        """Turn a 2D or 3D mask into a boolean mask broadcastable by SDPA."""
        mask = attention_mask.bool()
        if mask.dim() == 2:
            return mask.view(batch_size, 1, 1, -1)
        if mask.dim() == 3:
            # Per-sample valid length is read off the mask diagonal, then a
            # lower-triangular mask restricted to the valid tokens is built.
            # NOTE(review): this yields a causal mask -- confirm that is the
            # intended upstream behavior for 3D masks.
            seq = mask.shape[1]
            valid_len = torch.diagonal(mask, dim1=-2, dim2=-1).sum(dim=-1)
            positions = torch.arange(seq, device=mask.device)
            token_valid = positions.unsqueeze(0) < valid_len.unsqueeze(1)
            lower_tri = torch.tril(
                torch.ones(seq, seq, dtype=torch.bool, device=mask.device)
            )
            blocked = (
                lower_tri & token_valid.unsqueeze(-1) & token_valid.unsqueeze(-2)
            )
            return blocked.unsqueeze(1)
        raise ValueError(
            f"Unsupported attention_mask shape: {attention_mask.shape}"
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = hidden_states.shape

        q_proj = attn.to_q(hidden_states)
        k_proj = attn.to_k(encoder_hidden_states)
        v_proj = attn.to_v(encoder_hidden_states)

        head_dim = q_proj.shape[-1] // attn.heads
        kv_heads = k_proj.shape[-1] // head_dim
        proj_dtype = q_proj.dtype

        q = q_proj.view(batch_size, -1, attn.heads, head_dim)
        k = k_proj.view(batch_size, -1, kv_heads, head_dim)
        v = v_proj.view(batch_size, -1, kv_heads, head_dim)

        if attn.norm_q is not None:
            q = attn.norm_q(q)
        if attn.norm_k is not None:
            k = attn.norm_k(k)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb, use_real=False)
            k = apply_rotary_emb(k, image_rotary_emb, use_real=False)

        # Norm / RoPE may have changed the dtype; go back to the projection dtype.
        q, k = q.to(proj_dtype), k.to(proj_dtype)

        softmax_scale = attn.scale
        if base_sequence_length is not None:
            # Length-dependent rescaling of the softmax temperature.
            softmax_scale = (
                math.sqrt(math.log(sequence_length, base_sequence_length))
                * attn.scale
            )

        if self.attention_backend == "flash" and (
            attention_mask is None or attention_mask.dim() == 2
        ):
            if attention_mask is None:
                padding_mask = q.new_ones(
                    batch_size, sequence_length, dtype=torch.bool
                )
            else:
                padding_mask = attention_mask
            attn_out = _flash_varlen_attention(
                q, k, v, padding_mask, attn, softmax_scale
            )
        else:
            sdpa_mask = None
            if attention_mask is not None:
                sdpa_mask = self._expand_mask_for_sdpa(attention_mask, batch_size)

            q_t = q.transpose(1, 2)
            k_t = k.transpose(1, 2)
            v_t = v.transpose(1, 2)

            # Expand k/v heads explicitly rather than relying on enable_gqa.
            k_t = k_t.repeat_interleave(q_t.size(-3) // k_t.size(-3), -3)
            v_t = v_t.repeat_interleave(q_t.size(-3) // v_t.size(-3), -3)

            attn_out = F.scaled_dot_product_attention(
                q_t, k_t, v_t, attn_mask=sdpa_mask, scale=softmax_scale
            )
            attn_out = attn_out.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )

        # Shared tail: cast back and apply the output projection + dropout.
        attn_out = attn_out.type_as(q)
        attn_out = attn.to_out[0](attn_out)
        return attn.to_out[1](attn_out)
a/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/src/block_lumina2.py b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/src/block_lumina2.py new file mode 100644 index 0000000000000000000000000000000000000000..d243d58193c98b57913c9c92e84482309f4ea6f1 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/src/block_lumina2.py @@ -0,0 +1,164 @@ +# Vendored from the Boogu-Image repository (boogu/models/transformers/block_lumina2.py). +# Original work: Copyright 2025 BAAI / OmniGen2 / HuggingFace. Apache-2.0. +# +# The optional triton RMSNorm and flash-attn SwiGLU fast paths are dropped here; +# we always use torch.nn.RMSNorm and a plain SwiGLU so the model runs anywhere. +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.embeddings import Timesteps +from torch.nn import RMSNorm + +from .embeddings import TimestepEmbedding + + +def swiglu(x, y): + return F.silu(x.float(), inplace=False).to(x.dtype) * y + + +class LuminaRMSNormZero(nn.Module): + """Adaptive RMS normalization with a zero-initialized modulation projection.""" + + def __init__( + self, + embedding_dim: int, + norm_eps: float, + norm_elementwise_affine: bool, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + + self.norm = RMSNorm(embedding_dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + return x, gate_msa, scale_mlp, gate_mlp + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine=True, + eps=1e-5, + bias=True, + 
class LuminaFeedForward(nn.Module):
    """SwiGLU feed-forward layer: ``linear_2(swiglu(linear_1(x), linear_3(x)))``.

    Args:
        dim: input and output feature size.
        inner_dim: requested hidden size before scaling / rounding.
        multiple_of: the hidden size is rounded UP to a multiple of this value.
            ``None`` disables rounding (the parameter is typed ``Optional`` but
            ``None`` previously raised a TypeError).
        ffn_dim_multiplier: optional factor applied to ``inner_dim`` (truncated
            to int) before rounding.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        self.swiglu = swiglu

        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        if multiple_of is not None:
            # Ceil-divide, then multiply back: smallest multiple >= inner_dim.
            inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.linear_1 = nn.Linear(dim, inner_dim, bias=False)
        self.linear_2 = nn.Linear(inner_dim, dim, bias=False)
        self.linear_3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        # linear_1 feeds the SiLU gate, linear_3 the value branch.
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(self.swiglu(h1, h2))
num_channels=frequency_embedding_size, + flip_sin_to_cos=True, + downscale_freq_shift=0.0, + scale=timestep_scale, + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + RMSNorm(instruction_feat_dim, eps=norm_eps), + nn.Linear(instruction_feat_dim, hidden_size, bias=True), + ) + + self._initialize_weights() + + def _initialize_weights(self): + nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02) + nn.init.zeros_(self.caption_embedder[1].bias) + + def forward( + self, + timestep: torch.Tensor, + instruction_hidden_states: torch.Tensor, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).to(dtype=dtype) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(instruction_hidden_states) + return time_embed, caption_embed diff --git a/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/src/embeddings.py b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/src/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..f89429e70702f05a05fdb6ca2a661473c7725cc8 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/src/embeddings.py @@ -0,0 +1,112 @@ +# Vendored from the Boogu-Image repository (boogu/models/embeddings.py). +# Original work: Copyright 2024 The HuggingFace Team. Apache-2.0. +# +# Only the pieces the Boogu transformer actually needs are kept here: +# ``TimestepEmbedding`` and ``apply_rotary_emb``. 
class TimestepEmbedding(nn.Module):
    """Two-layer MLP that embeds a (sinusoidal) timestep projection.

    Args:
        in_channels: Width of the incoming projection.
        time_embed_dim: Hidden width (and output width unless ``out_dim`` is set).
        act_fn: Name of the activation between the two linear layers.
        out_dim: Optional output width overriding ``time_embed_dim``.
        post_act_fn: Optional activation applied to the output.
        cond_proj_dim: When set, a bias-free ``Linear(cond_proj_dim,
            in_channels)`` is created so ``forward`` can add a projected
            ``condition`` to the input.
        sample_proj_bias: Whether the two linear layers use a bias.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

        self.initialize_weights()

    def initialize_weights(self):
        """Normal(std=0.02) weights and zero biases for both linear layers."""
        nn.init.normal_(self.linear_1.weight, std=0.02)
        nn.init.zeros_(self.linear_1.bias)
        nn.init.normal_(self.linear_2.weight, std=0.02)
        nn.init.zeros_(self.linear_2.bias)

    def forward(self, sample, condition=None):
        """Embed ``sample``; optionally add a projected ``condition`` first.

        Raises:
            ValueError: If ``condition`` is passed but the module was built
                without ``cond_proj_dim`` (previously this surfaced as an
                opaque "'NoneType' object is not callable").
        """
        if condition is not None:
            if self.cond_proj is None:
                raise ValueError(
                    "condition was provided but TimestepEmbedding was created "
                    "without cond_proj_dim"
                )
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """Apply rotary embeddings to ``x`` and return the rotated tensor.

    Boogu always calls this with ``use_real=False`` (the Lumina-style complex
    path): ``freqs_cis`` is a complex tensor and ``x`` is reinterpreted as
    complex, multiplied, and returned as real.

    Args:
        x: Tensor to rotate; its last dimension holds the (real, imag) pairs.
        freqs_cis: With ``use_real=True`` a ``(cos, sin)`` pair, each ``[S, D]``;
            otherwise a complex tensor that gets a broadcast axis inserted at
            dim 2 before it multiplies ``x``.
        use_real: Selects the cos/sin path (``True``) or the complex path.
        use_real_unbind_dim: Pair layout for the cos/sin path: ``-1`` for
            interleaved pairs, ``-2`` for split halves.

    Returns:
        A single tensor with the shape and dtype of ``x`` (the annotation used
        to claim a 2-tuple, which no branch ever returned).

    Raises:
        ValueError: If ``use_real_unbind_dim`` is neither ``-1`` nor ``-2``.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, Boogu and CogView4
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(
                f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2."
            )

        # The rotation is evaluated in float32 and cast back to x's dtype.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina / boogu
        x_rotated = torch.view_as_complex(
            x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
        )
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
def pad_instruction_features(
    features_list: List[torch.Tensor],
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad per-sample ``(L_i, D)`` instruction features into a batch.

    Captions are stored per-sample at their natural length and only padded to
    the batch max here, right before the model call.

    Args:
        features_list: One ``(L_i, D)`` tensor per sample; ``D`` is read from
            the first entry.
        device: Device of the returned tensors.
        dtype: Dtype of the returned features (the mask is always ``long``).

    Returns:
        ``(features (B, L, D), attention_mask (B, L))`` with the mask 1 for
        real tokens and 0 for padding; padded feature rows are zero.

    Raises:
        ValueError: If ``features_list`` is empty (there is no length or width
            to pad to).
    """
    if len(features_list) == 0:
        raise ValueError("features_list must contain at least one (L_i, D) tensor")

    lengths = [f.shape[0] for f in features_list]
    max_len = max(lengths)
    dim = features_list[0].shape[-1]
    batch_size = len(features_list)

    features = torch.zeros(batch_size, max_len, dim, device=device, dtype=dtype)
    mask = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)
    for i, (f, n) in enumerate(zip(features_list, lengths)):
        features[i, :n] = f.to(device, dtype)
        mask[i, :n] = 1
    return features, mask
+# --------------------------------------------------------------------------- + + +def _lin_shift( + num_tokens: float, + x1: float = 256.0, + y1: float = 0.5, + x2: float = 4096.0, + y2: float = 1.15, +) -> float: + """Linear token-count -> mu mapping (Boogu base_shift/max_shift defaults).""" + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return m * num_tokens + b + + +def boogu_time_schedule( + num_steps: int, + num_patch_tokens: int, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Boogu native-domain timesteps (0=noise .. 1=clean) with v1 time shift. + + Returns a length ``num_steps + 1`` tensor; the trailing ``1.0`` is the clean + endpoint, matching the ``_timesteps`` tail in the reference scheduler. + """ + t_arr = np.linspace(0.0, 1.0, num_steps + 1, dtype=np.float32)[:-1] + + mu = _lin_shift(max(1, int(num_patch_tokens))) + eps = 1e-8 + t1 = np.clip(1.0 - t_arr, eps, 1.0 - eps) + num = math.exp(mu) + denom = num + (1.0 / t1 - 1.0) + t_arr = (1.0 - num / denom).astype(np.float32) + + times = np.concatenate([t_arr, np.ones(1, dtype=np.float32)]) + return torch.from_numpy(times).to(device=device, dtype=torch.float32) + + +# --------------------------------------------------------------------------- +# Transformer call (Boogu native time domain). +# --------------------------------------------------------------------------- + + +def run_boogu_transformer( + transformer: BooguImageTransformer2DModel, + latents: torch.Tensor, # (B, 16, H, W) + boogu_t: torch.Tensor, # (B,) in [0, 1], 0=noise, 1=clean + instruction_features: torch.Tensor, # (B, L, instruction_feat_dim) + instruction_mask: torch.Tensor, # (B, L) 1 for real tokens + freqs_cis, # precomputed per-axis rotary tables + ref_image_hidden_states=None, # edit/TI2I: List[List[(16, H, W)]] per batch item +) -> torch.Tensor: + """Run the transformer and return the raw model velocity (``clean - noise``). 
+ + Shapes pass straight through: the prediction comes back as ``(B, 16, H, W)`` + in the same latent layout as ``latents``. ``ref_image_hidden_states`` stays + ``None`` for the base T2I model and carries reference-image VAE latents for + the edit (TI2I) model. + """ + out = transformer( + hidden_states=latents, + timestep=boogu_t, + instruction_hidden_states=instruction_features, + freqs_cis=freqs_cis, + instruction_attention_mask=instruction_mask, + ref_image_hidden_states=ref_image_hidden_states, + return_dict=False, + ) + return out + + +# --------------------------------------------------------------------------- +# Minimal sampling pipeline (for training previews). +# --------------------------------------------------------------------------- + + +class BooguImagePipeline: + """Lightweight flow-matching sampler used by ai-toolkit's preview generation.""" + + def __init__(self, model): + # ``model`` is the BooguImageModel so we can reuse its encode/decode and + # latent helpers without duplicating state. 
+ self.model = model + + @property + def device(self): + return self.model.device_torch + + def to(self, *args, **kwargs): + return self + + @torch.no_grad() + def __call__( + self, + conditional_embeds, + unconditional_embeds, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 4.0, + latents: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ref_latents=None, # edit/TI2I: List[List[(16, H, W)]] reference VAE latents + **kwargs, + ) -> List[Image.Image]: + model = self.model + device = model.device_torch + dtype = model.torch_dtype + transformer = model.transformer + patch = model.patch_size + ae_scale = model.vae_scale_factor # 8 + + latent_channels = transformer.config.in_channels + h_lat = height // ae_scale + w_lat = width // ae_scale + num_patch_tokens = (h_lat // patch) * (w_lat // patch) + + freqs_cis = model.get_freqs_cis() + + do_cfg = guidance_scale > 1.0 + + if latents is None: + shape = (1, latent_channels, h_lat, w_lat) + latents = randn_tensor( + shape, generator=generator, device=device, dtype=torch.float32 + ) + # In Boogu's domain t=0 is pure noise, so the initial latent IS the noise. 
+ latents = latents.to(device, dtype=torch.float32) + + cond_feats, cond_mask = pad_instruction_features( + conditional_embeds.text_embeds, device, dtype + ) + if do_cfg: + uncond_feats, uncond_mask = pad_instruction_features( + unconditional_embeds.text_embeds, device, dtype + ) + + times = boogu_time_schedule(num_inference_steps, num_patch_tokens, device) + + for t, t_next in zip(times[:-1], times[1:]): + boogu_t = t.expand(latents.shape[0]) + v_cond = run_boogu_transformer( + transformer, + latents.to(dtype), + boogu_t, + cond_feats, + cond_mask, + freqs_cis, + ref_image_hidden_states=ref_latents, + ) + if do_cfg: + v_uncond = run_boogu_transformer( + transformer, + latents.to(dtype), + boogu_t, + uncond_feats, + uncond_mask, + freqs_cis, + ref_image_hidden_states=ref_latents, + ) + v = v_uncond + guidance_scale * (v_cond - v_uncond) + else: + v = v_cond + latents = latents + v.to(torch.float32) * (t_next - t) + + images = model.decode_latents(latents, device=device, dtype=dtype) + images = images.float().clamp(-1.0, 1.0) + images = ((images + 1.0) * 127.5).round().to(torch.uint8) + images = images.permute(0, 2, 3, 1).cpu().numpy() + return [Image.fromarray(arr) for arr in images] diff --git a/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/src/rope.py b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/src/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..018445470a3a331e5424d0af5f50c80b7eb2fa2a --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/boogu_image/src/rope.py @@ -0,0 +1,244 @@ +# Vendored from the Boogu-Image repository (boogu/models/transformers/rope.py). +# Original work: Copyright 2025 BAAI / OmniGen2 / HuggingFace. Apache-2.0. +# +# Only the double-stream rotary embedder (the one the transformer uses) and the +# ``get_freqs_cis`` precompute helper are kept. The MPS-specific branch is +# preserved verbatim. 
def get_freqs_cis(
    axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int], theta: int
) -> List[torch.Tensor]:
    """Precompute the per-axis rotary frequency tables (done once per resolution).

    One table is built per ``(axes_dim[i], axes_lens[i])`` pair with
    ``get_1d_rotary_pos_embed``. float32 is used when an MPS backend is
    available, float64 otherwise.
    """
    freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
    return [
        get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
        for d, e in zip(axes_dim, axes_lens)
    ]


class BooguImageDoubleStreamRotaryPosEmbed(nn.Module):
    """Builds 3-axis rotary embeddings for ``[caption | ref images | image]`` tokens.

    Axis 0 is a sequence/shift index, axes 1 and 2 are the row and column of a
    patch. The module owns no parameters; the frequency tables are passed in.

    Args:
        theta: Rotary base, kept for :meth:`get_freqs_cis`.
        axes_dim: Per-axis rotary width.
        axes_lens: Per-axis table length.
        patch_size: Latent pixels per patch side.
    """

    def __init__(
        self,
        theta: int,
        axes_dim: Tuple[int, int, int],
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        patch_size: int = 2,
    ):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(
        axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int], theta: int
    ) -> List[torch.Tensor]:
        """Static alias of the module-level :func:`get_freqs_cis`."""
        return get_freqs_cis(axes_dim, axes_lens, theta)

    @staticmethod
    def _grid_ids(h_tokens: int, w_tokens: int, device):
        """Return flattened row-major ``(row_ids, col_ids)`` for an ``h x w`` grid."""
        row_ids = torch.arange(
            h_tokens, dtype=torch.int32, device=device
        ).repeat_interleave(w_tokens)
        col_ids = torch.arange(w_tokens, dtype=torch.int32, device=device).repeat(
            h_tokens
        )
        return row_ids, col_ids

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        """Look up and concatenate the per-axis table rows selected by ``ids``.

        Args:
            freqs_cis: One ``(axis_len, width_i)`` table per axis.
            ids: ``(B, S, n_axes)`` integer position ids.

        Returns:
            ``(B, S, sum(width_i))`` tensor on the original device of ``ids``.
        """
        device = ids.device
        # MPS-specific branch: the lookup runs on CPU and is moved back after.
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        result = []
        for axis in range(len(self.axes_dim)):
            table = freqs_cis[axis].to(ids.device)
            # Row lookup; equivalent to gathering along dim 1 of a
            # batch-repeated table, without materialising the copies.
            result.append(table[ids[:, :, axis].to(torch.int64)])
        return torch.cat(result, dim=-1).to(device)

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device,
    ):
        """Assemble position ids and slice the rotary embeddings per stream.

        Args:
            freqs_cis: Per-axis tables from :func:`get_freqs_cis`.
            attention_mask: ``(B, L_cap)`` caption mask; its row sums give the
                effective caption lengths.
            l_effective_ref_img_len: Per sample, a list of token counts, one
                per reference image.
            l_effective_img_len: Per sample, the target image token count.
            ref_img_sizes: Per sample, ``None`` or a list of ``(H, W)`` latent
                sizes for the reference images.
            img_sizes: Per sample, the ``(H, W)`` latent size of the target.
            device: Device for the created tensors.

        Returns:
            ``(cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis,
            l_effective_cap_len, seq_lengths, combined_img_freqs_cis,
            combined_img_seq_lengths)``. The per-stream tensors are zero-padded
            to the batch maximum of their stream.
        """
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()

        # Total reference-image tokens per sample, computed once.
        ref_totals = [sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]

        # Combined image sequence lengths (ref_img + img) for each sample.
        combined_img_seq_lengths = [
            ref_total + img_len
            for ref_total, img_len in zip(ref_totals, l_effective_img_len)
        ]
        seq_lengths = [
            cap_len + combined_len
            for cap_len, combined_len in zip(
                l_effective_cap_len, combined_img_seq_lengths
            )
        ]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max(ref_totals)
        max_img_len = max(l_effective_img_len)
        max_combined_img_len = max(combined_img_seq_lengths)

        # Create position IDs
        position_ids = torch.zeros(
            batch_size, max_seq_len, 3, dtype=torch.int32, device=device
        )

        for i, (cap_seq_len, seq_len) in enumerate(
            zip(l_effective_cap_len, seq_lengths)
        ):
            # Text tokens: the same running index on all three axes.
            position_ids[i, :cap_seq_len] = torch.arange(
                cap_seq_len, dtype=torch.int32, device=device
            )[:, None]

            pe_shift = cap_seq_len  # value written to axis 0
            pe_shift_len = cap_seq_len  # write offset in the sequence

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(
                    ref_img_sizes[i], l_effective_ref_img_len[i]
                ):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len

                    row_ids, col_ids = self._grid_ids(
                        ref_H_tokens, ref_W_tokens, device
                    )
                    span = slice(pe_shift_len, pe_shift_len + ref_img_len)
                    position_ids[i, span, 0] = pe_shift
                    position_ids[i, span, 1] = row_ids
                    position_ids[i, span, 2] = col_ids

                    # Each image advances axis 0 by its longer token side.
                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i]

            row_ids, col_ids = self._grid_ids(H_tokens, W_tokens, device)

            assert pe_shift_len + l_effective_img_len[i] == seq_len
            position_ids[i, pe_shift_len:seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len:seq_len, 1] = row_ids
            position_ids[i, pe_shift_len:seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)

        def _padded(length: int) -> torch.Tensor:
            # Zero-padded per-stream buffer matching the combined embeddings.
            return torch.zeros(
                batch_size,
                length,
                freqs_cis.shape[-1],
                device=device,
                dtype=freqs_cis.dtype,
            )

        # Separate rotary embeddings for captions, reference images and images,
        # plus the combined image stream.
        cap_freqs_cis = _padded(encoder_seq_len)
        ref_img_freqs_cis = _padded(max_ref_img_len)
        img_freqs_cis = _padded(max_img_len)
        combined_img_freqs_cis = _padded(max_combined_img_len)

        for i, (cap_seq_len, ref_total, img_len, seq_len) in enumerate(
            zip(l_effective_cap_len, ref_totals, l_effective_img_len, seq_lengths)
        ):
            ref_end = cap_seq_len + ref_total
            img_end = ref_end + img_len

            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :ref_total] = freqs_cis[i, cap_seq_len:ref_end]
            img_freqs_cis[i, :img_len] = freqs_cis[i, ref_end:img_end]

            # Combined image rotary embeddings: ref_img + img (same order as
            # img_patch_embed_and_refine); the two spans are contiguous.
            combined_img_freqs_cis[i, : ref_total + img_len] = freqs_cis[
                i, cap_seq_len:img_end
            ]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
            combined_img_freqs_cis,
            combined_img_seq_lengths,
        )
class BooguImageTransformerBlock(nn.Module):
    """Basic Boogu-Image transformer block: attention + SwiGLU MLP + RMSNorm.

    Args:
        dim: Hidden width.
        num_attention_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        multiple_of: Rounding multiple for the feed-forward inner width.
        ffn_dim_multiplier: Optional scale on the feed-forward inner width.
        norm_eps: Epsilon of the block's RMSNorm layers.
        modulation: When true the first norm is a ``LuminaRMSNormZero`` driven
            by ``temb`` and both residual branches are tanh-gated.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=BooguImageAttnProcessor(),
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)

        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Xavier-init the projections; zero the modulation projection."""
        nn.init.xavier_uniform_(self.attn.to_q.weight)
        nn.init.xavier_uniform_(self.attn.to_k.weight)
        nn.init.xavier_uniform_(self.attn.to_v.weight)
        nn.init.xavier_uniform_(self.attn.to_out[0].weight)

        nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)

        if self.modulation:
            # Zero scale/gate projection, so the gated branches start at zero.
            nn.init.zeros_(self.norm1.linear.weight)
            nn.init.zeros_(self.norm1.linear.bias)

    def _self_attention(self, norm_hidden_states, attention_mask, image_rotary_emb):
        """Self-attention over the already normalized hidden states."""
        return self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward residual branches.

        Raises:
            ValueError: If the block is modulated and ``temb`` is missing.
        """
        if not self.modulation:
            attn_output = self._self_attention(
                self.norm1(hidden_states), attention_mask, image_rotary_emb
            )
            hidden_states = hidden_states + self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
            return hidden_states + self.ffn_norm2(mlp_output)

        if temb is None:
            raise ValueError("temb must be provided when modulation is enabled")

        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(
            hidden_states, temb
        )
        attn_output = self._self_attention(
            norm_hidden_states, attention_mask, image_rotary_emb
        )
        # Gates are per-sample; unsqueeze adds the sequence axis.
        hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(
            attn_output
        )
        mlp_output = self.feed_forward(
            self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))
        )
        hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
            mlp_output
        )
        return hidden_states


class BooguImageNoiseRefinerTransformerBlock(BooguImageTransformerBlock):
    """Alias of :class:`BooguImageTransformerBlock` (no behavior added)."""


class BooguImageRefImgRefinerTransformerBlock(BooguImageTransformerBlock):
    """Alias of :class:`BooguImageTransformerBlock` (no behavior added)."""


class BooguImageContextRefinerTransformerBlock(BooguImageTransformerBlock):
    """Alias of :class:`BooguImageTransformerBlock` (no behavior added)."""


class BooguImageSingleStreamTransformerBlock(BooguImageTransformerBlock):
    """Alias of :class:`BooguImageTransformerBlock` (no behavior added)."""


class BooguImageDoubleStreamTransformerBlock(nn.Module):
    """Boogu-Image double-stream block: instruction & image tokens in parallel streams.

    Both streams share one joint attention over ``[instruct + img]``; the image
    stream additionally runs its own self-attention. Each stream has its own
    feed-forward and norms.

    Args:
        dim: Hidden width.
        num_attention_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        multiple_of: Rounding multiple for the feed-forward inner width.
        ffn_dim_multiplier: Optional scale on the feed-forward inner width.
        norm_eps: Epsilon of the block's RMSNorm layers.
        modulation: When true the ``*_norm1/2/3`` layers are
            ``LuminaRMSNormZero`` driven by ``temb``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.num_attention_heads = num_attention_heads
        self.modulation = modulation
        self.hidden_size = dim

        def make_attention(processor):
            return Attention(
                query_dim=dim,
                cross_attention_dim=None,
                dim_head=dim // num_attention_heads,
                qk_norm="rms_norm",
                heads=num_attention_heads,
                kv_heads=num_kv_heads,
                eps=1e-5,
                bias=False,
                out_bias=False,
                processor=processor,
            )

        def make_feed_forward():
            return LuminaFeedForward(
                dim=dim,
                inner_dim=4 * dim,
                multiple_of=multiple_of,
                ffn_dim_multiplier=ffn_dim_multiplier,
            )

        def make_stream_norm():
            if modulation:
                return LuminaRMSNormZero(
                    embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True
                )
            return RMSNorm(dim, eps=norm_eps)

        # NOTE: submodules are created in the original order so state-dict key
        # order and random-init RNG consumption stay unchanged.
        double_stream_processor = BooguImageDoubleStreamSelfAttnProcessor(
            head_dim=self.head_dim,
            num_attention_heads=num_attention_heads,
            num_kv_heads=num_kv_heads,
            qkv_bias=False,
        )

        # Image stream components.
        self.img_instruct_attn = make_attention(double_stream_processor)
        self.img_self_attn = make_attention(BooguImageAttnProcessor())
        self.img_feed_forward = make_feed_forward()

        self.img_norm1 = make_stream_norm()
        self.img_norm2 = make_stream_norm()
        self.img_norm3 = make_stream_norm()

        self.img_ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.img_attn_norm = RMSNorm(dim, eps=norm_eps)
        self.img_self_attn_norm = RMSNorm(dim, eps=norm_eps)
        self.img_ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        # Instruction stream components.
        self.instruct_feed_forward = make_feed_forward()

        self.instruct_norm1 = make_stream_norm()
        self.instruct_norm2 = make_stream_norm()

        self.instruct_ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.instruct_attn_norm = RMSNorm(dim, eps=norm_eps)
        self.instruct_ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.initialize_weights()

        # double_stream_processor owns its own q/k/v projections, so the wrapping
        # Attention's q/k/v are unused -- drop them so they aren't saved/loaded.
        for name in ("to_q", "to_k", "to_v"):
            for param in getattr(self.img_instruct_attn, name).parameters():
                param.requires_grad = False

        del self.img_instruct_attn.to_k
        del self.img_instruct_attn.to_v
        del self.img_instruct_attn.to_q

    def initialize_weights(self) -> None:
        """Xavier-init the projections; zero every modulation projection."""
        nn.init.xavier_uniform_(self.img_instruct_attn.to_out[0].weight)

        nn.init.xavier_uniform_(self.img_self_attn.to_q.weight)
        nn.init.xavier_uniform_(self.img_self_attn.to_k.weight)
        nn.init.xavier_uniform_(self.img_self_attn.to_v.weight)
        nn.init.xavier_uniform_(self.img_self_attn.to_out[0].weight)

        for feed_forward in (self.img_feed_forward, self.instruct_feed_forward):
            nn.init.xavier_uniform_(feed_forward.linear_1.weight)
            nn.init.xavier_uniform_(feed_forward.linear_2.weight)
            nn.init.xavier_uniform_(feed_forward.linear_3.weight)

        if self.modulation:
            for norm in (
                self.img_norm1,
                self.img_norm2,
                self.img_norm3,
                self.instruct_norm1,
                self.instruct_norm2,
            ):
                nn.init.zeros_(norm.linear.weight)
                nn.init.zeros_(norm.linear.bias)

    def _run_attention(
        self,
        img_hidden_states,
        instruct_hidden_states,
        img_norm1_out,
        instruct_norm1_out,
        img_norm3_out,
        img_attention_mask,
        joint_attention_mask,
        image_rotary_emb,
        rotary_emb,
        encoder_seq_lengths,
        seq_lengths,
    ):
        """Joint attention, per-sample split into streams, then image self-attention.

        Shared by the modulated and non-modulated branches of ``forward``.
        The un-normalized hidden states are only used for shape, dtype and
        device of the zero-padded outputs.

        Returns:
            ``(img_attn_out, instruct_attn_out, img_self_attn_out)``.
        """
        joint_attn_out = self.img_instruct_attn.processor(
            attn=self.img_instruct_attn,
            img_hidden_states=img_norm1_out,
            instruct_hidden_states=instruct_norm1_out,
            joint_attention_mask=joint_attention_mask,
            rotary_emb=rotary_emb,
            encoder_seq_lengths=encoder_seq_lengths,
            seq_lengths=seq_lengths,
        )

        batch_size = img_hidden_states.shape[0]
        instruct_attn_out = instruct_hidden_states.new_zeros(
            batch_size, instruct_hidden_states.shape[1], self.hidden_size
        )
        img_attn_out = img_hidden_states.new_zeros(
            batch_size, img_hidden_states.shape[1], self.hidden_size
        )
        # The joint sequence of sample i is [instruct (encoder_seq_len) | img];
        # copy each part into its own buffer, leaving the padding at zero.
        for i, (encoder_seq_len, seq_len) in enumerate(
            zip(encoder_seq_lengths, seq_lengths)
        ):
            instruct_attn_out[i, :encoder_seq_len] = joint_attn_out[
                i, :encoder_seq_len
            ]
            img_attn_out[i, : seq_len - encoder_seq_len] = joint_attn_out[
                i, encoder_seq_len:seq_len
            ]

        img_self_attn_out = self.img_self_attn(
            hidden_states=img_norm3_out,
            encoder_hidden_states=img_norm3_out,
            attention_mask=img_attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        return img_attn_out, instruct_attn_out, img_self_attn_out

    def forward(
        self,
        img_hidden_states: torch.Tensor,  # [B, L_img, D]
        instruct_hidden_states: torch.Tensor,  # [B, L_instruct, D]
        img_attention_mask: torch.Tensor,  # [B, L_img]
        joint_attention_mask: torch.Tensor,  # [B, L_total]
        image_rotary_emb: torch.Tensor,  # [B, L_img, head_dim]
        rotary_emb: torch.Tensor,  # [B, L_total, head_dim]
        temb: Optional[torch.Tensor] = None,  # [B, 1024]
        encoder_seq_lengths: Optional[List[int]] = None,  # [B]
        seq_lengths: Optional[List[int]] = None,  # [B]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update both streams and return ``(img_hidden_states, instruct_hidden_states)``.

        Raises:
            ValueError: If the block is modulated and ``temb`` is missing, or
                if ``encoder_seq_lengths`` / ``seq_lengths`` are not supplied
                (they are needed to split the joint attention output).
        """
        if self.modulation and temb is None:
            raise ValueError("temb must be provided when modulation is enabled")
        if encoder_seq_lengths is None or seq_lengths is None:
            raise ValueError("encoder_seq_lengths and seq_lengths must be provided")

        if self.modulation:
            # Step 1: modulation for both streams. All norms read the incoming
            # (pre-residual) hidden states.
            img_norm1_out, img_gate_msa, img_scale_mlp, img_gate_mlp = self.img_norm1(
                img_hidden_states, temb
            )
            img_norm2_out, img_shift_mlp, _, _ = self.img_norm2(img_hidden_states, temb)
            img_norm3_out, img_gate_self, _, _ = self.img_norm3(img_hidden_states, temb)

            (
                instruct_norm1_out,
                instruct_gate_msa,
                instruct_scale_mlp,
                instruct_gate_mlp,
            ) = self.instruct_norm1(instruct_hidden_states, temb)
            instruct_norm2_out, instruct_shift_mlp, _, _ = self.instruct_norm2(
                instruct_hidden_states, temb
            )

            # Steps 2-3: joint attention on [instruct + img], image self-attention.
            img_attn_out, instruct_attn_out, img_self_attn_out = self._run_attention(
                img_hidden_states,
                instruct_hidden_states,
                img_norm1_out,
                instruct_norm1_out,
                img_norm3_out,
                img_attention_mask,
                joint_attention_mask,
                image_rotary_emb,
                rotary_emb,
                encoder_seq_lengths,
                seq_lengths,
            )

            # Step 4: residual updates.
            img_hidden_states = img_hidden_states + img_gate_msa.unsqueeze(
                1
            ).tanh() * self.img_attn_norm(img_attn_out)
            img_hidden_states = img_hidden_states + img_gate_self.unsqueeze(
                1
            ).tanh() * self.img_self_attn_norm(img_self_attn_out)

            img_mlp_input = (
                1 + img_scale_mlp.unsqueeze(1)
            ) * img_norm2_out + img_shift_mlp.unsqueeze(1)
            img_mlp_out = self.img_feed_forward(self.img_ffn_norm1(img_mlp_input))
            img_hidden_states = img_hidden_states + img_gate_mlp.unsqueeze(
                1
            ).tanh() * self.img_ffn_norm2(img_mlp_out)

            instruct_hidden_states = (
                instruct_hidden_states
                + instruct_gate_msa.unsqueeze(1).tanh()
                * self.instruct_attn_norm(instruct_attn_out)
            )

            instruct_mlp_input = (
                1 + instruct_scale_mlp.unsqueeze(1)
            ) * instruct_norm2_out + instruct_shift_mlp.unsqueeze(1)
            instruct_mlp_out = self.instruct_feed_forward(
                self.instruct_ffn_norm1(instruct_mlp_input)
            )
            instruct_hidden_states = (
                instruct_hidden_states
                + instruct_gate_mlp.unsqueeze(1).tanh()
                * self.instruct_ffn_norm2(instruct_mlp_out)
            )

        else:
            # Non-modulated branch used by context-style blocks.
            img_norm1_out = self.img_norm1(img_hidden_states)
            img_norm3_out = self.img_norm3(img_hidden_states)
            instruct_norm1_out = self.instruct_norm1(instruct_hidden_states)

            img_attn_out, instruct_attn_out, img_self_attn_out = self._run_attention(
                img_hidden_states,
                instruct_hidden_states,
                img_norm1_out,
                instruct_norm1_out,
                img_norm3_out,
                img_attention_mask,
                joint_attention_mask,
                image_rotary_emb,
                rotary_emb,
                encoder_seq_lengths,
                seq_lengths,
            )

            img_hidden_states = img_hidden_states + self.img_attn_norm(img_attn_out)
            img_hidden_states = img_hidden_states + self.img_self_attn_norm(
                img_self_attn_out
            )
            # Here the MLP norm reads the post-attention hidden states.
            img_norm2_out = self.img_norm2(img_hidden_states)
            img_mlp_out = self.img_feed_forward(self.img_ffn_norm1(img_norm2_out))
            img_hidden_states = img_hidden_states + self.img_ffn_norm2(img_mlp_out)

            instruct_hidden_states = instruct_hidden_states + self.instruct_attn_norm(
                instruct_attn_out
            )
            instruct_norm2_out = self.instruct_norm2(instruct_hidden_states)
            instruct_mlp_out = self.instruct_feed_forward(
                self.instruct_ffn_norm1(instruct_norm2_out)
            )
            instruct_hidden_states = instruct_hidden_states + self.instruct_ffn_norm2(
                instruct_mlp_out
            )

        return img_hidden_states, instruct_hidden_states
ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin +): + """Boogu-Image transformer with mixed double-stream -> single-stream topology.""" + + _supports_gradient_checkpointing = True + _no_split_modules = [ + "BooguImageTransformerBlock", + "BooguImageNoiseRefinerTransformerBlock", + "BooguImageRefImgRefinerTransformerBlock", + "BooguImageContextRefinerTransformerBlock", + "BooguImageSingleStreamTransformerBlock", + "BooguImageDoubleStreamTransformerBlock", + ] + _skip_layerwise_casting_patterns = ["x_embedder", "norm", "embedding"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 3360, + num_layers: int = 40, + num_double_stream_layers: int = 8, + num_refiner_layers: int = 2, + num_attention_heads: int = 28, + num_kv_heads: int = 7, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + axes_dim_rope: Tuple[int, int, int] = (40, 40, 40), + axes_lens: Tuple[int, int, int] = (2048, 1664, 1664), + instruction_feature_configs: Dict[str, Any] = dict( + instruction_feat_dim=4096, + reduce_type="mean", + num_instruction_feat_layers=1, + ), + prompt_tuning_configs: Dict[str, Any] = dict(use_prompt_tuning=False), + timestep_scale: float = 1000.0, + ) -> None: + super().__init__() + + if (hidden_size // num_attention_heads) != sum(axes_dim_rope): + raise ValueError( + f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) " + f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})" + ) + + if num_double_stream_layers > num_layers: + raise ValueError( + f"num_double_stream_layers ({num_double_stream_layers}) cannot be greater " + f"than num_layers ({num_layers})" + ) + + self.out_channels = out_channels or in_channels + self.num_double_stream_layers = num_double_stream_layers + self.num_single_stream_layers = num_layers - num_double_stream_layers + self.instruction_feature_configs = 
instruction_feature_configs + self.prompt_tuning_configs = prompt_tuning_configs + self.preprocessed_instruction_feat_dim = ( + self.cal_preprocessed_instruction_feat_dim(instruction_feature_configs) + ) + + self.rope_embedder = BooguImageDoubleStreamRotaryPosEmbed( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + patch_size=patch_size, + ) + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.ref_image_patch_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, + instruction_feat_dim=self.preprocessed_instruction_feat_dim, + norm_eps=norm_eps, + timestep_scale=timestep_scale, + ) + + self.noise_refiner = nn.ModuleList( + [ + BooguImageNoiseRefinerTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.ref_image_refiner = nn.ModuleList( + [ + BooguImageRefImgRefinerTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + BooguImageContextRefinerTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=False, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.double_stream_layers = nn.ModuleList( + [ + BooguImageDoubleStreamTransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_double_stream_layers) + ] + ) + + self.single_stream_layers = nn.ModuleList( + [ + BooguImageSingleStreamTransformerBlock( + hidden_size, + 
num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(self.num_single_stream_layers) + ] + ) + + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + + # Distinguish multiple reference images (supports up to 5 ref images). + self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) + + self.gradient_checkpointing = False + # Attention defaults to torch SDPA ("native"); flip to "flash" via + # set_attention_backend when flash-attn is installed and wanted. + self.attention_backend = "native" + + self.initialize_weights() + + self.layers = list(self.double_stream_layers) + list(self.single_stream_layers) + + def initialize_weights(self) -> None: + nn.init.xavier_uniform_(self.x_embedder.weight) + nn.init.constant_(self.x_embedder.bias, 0.0) + + nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight) + nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0) + + nn.init.zeros_(self.norm_out.linear_1.weight) + nn.init.zeros_(self.norm_out.linear_1.bias) + nn.init.zeros_(self.norm_out.linear_2.weight) + nn.init.zeros_(self.norm_out.linear_2.bias) + + nn.init.normal_(self.image_index_embedding, std=0.02) + + def set_attention_backend(self, backend: str) -> None: + """Select the attention implementation for every attention module. + + Args: + backend: "native" for ``F.scaled_dot_product_attention`` (the default, + no extra dependency) or "flash" for Flash Attention 2 + (``flash_attn_varlen_func``). Selecting "flash" requires the + ``flash_attn`` package to be installed. + """ + backend = backend.lower() + if backend not in ATTENTION_BACKENDS: + raise ValueError( + f"Unknown attention backend {backend!r}. " + f"Expected one of {ATTENTION_BACKENDS}." 
+ ) + if backend == "flash" and not _FLASH_ATTN_AVAILABLE: + raise RuntimeError( + "Flash attention 2 backend requested but the `flash_attn` package " + "is not installed. Install it with `pip install flash-attn` or use " + "the 'native' backend." + ) + self.attention_backend = backend + # Processors live on the wrapping diffusers Attention modules. The single + # -stream processor is a plain object (not an nn.Module) so it isn't in + # self.modules(); reach every processor through its Attention instead. + for module in self.modules(): + if isinstance(module, Attention): + processor = getattr(module, "processor", None) + if hasattr(processor, "attention_backend"): + processor.attention_backend = backend + + def _ckpt(self, layer, *args): + """Run ``layer`` with activation checkpointing when training, else directly.""" + if torch.is_grad_enabled() and self.gradient_checkpointing: + return self._gradient_checkpointing_func(layer, *args) + return layer(*args) + + def img_patch_embed_and_refine( + self, + hidden_states, + ref_image_hidden_states, + padded_img_mask, + padded_ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb, + ): + """Embed image patches and run the refiner blocks.""" + batch_size = len(hidden_states) + max_combined_img_len = max( + [ + img_len + sum(ref_img_len) + for img_len, ref_img_len in zip( + l_effective_img_len, l_effective_ref_img_len + ) + ] + ) + + hidden_states = self.x_embedder(hidden_states) + ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) + + for i in range(batch_size): + shift = 0 + for j, ref_img_len in enumerate(l_effective_ref_img_len[i]): + ref_image_hidden_states[i, shift : shift + ref_img_len, :] = ( + ref_image_hidden_states[i, shift : shift + ref_img_len, :] + + self.image_index_embedding[j] + ) + shift += ref_img_len + + for layer in self.noise_refiner: + hidden_states = self._ckpt( + layer, hidden_states, padded_img_mask, 
noise_rotary_emb, temb + ) + + flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len)) + num_ref_images = len(flat_l_effective_ref_img_len) + max_ref_img_len = max(flat_l_effective_ref_img_len) + + batch_ref_img_mask = ref_image_hidden_states.new_zeros( + num_ref_images, max_ref_img_len, dtype=torch.bool + ) + batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros( + num_ref_images, max_ref_img_len, self.config.hidden_size + ) + batch_ref_img_rotary_emb = hidden_states.new_zeros( + num_ref_images, + max_ref_img_len, + ref_img_rotary_emb.shape[-1], + dtype=ref_img_rotary_emb.dtype, + ) + batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype) + + # Flatten reference images into a temporary batch. + idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + batch_ref_img_mask[idx, :ref_img_len] = True + batch_ref_image_hidden_states[idx, :ref_img_len] = ( + ref_image_hidden_states[i, shift : shift + ref_img_len] + ) + batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[ + i, shift : shift + ref_img_len + ] + batch_temb[idx] = temb[i] + shift += ref_img_len + idx += 1 + + for layer in self.ref_image_refiner: + batch_ref_image_hidden_states = self._ckpt( + layer, + batch_ref_image_hidden_states, + batch_ref_img_mask, + batch_ref_img_rotary_emb, + batch_temb, + ) + + # Restore reference-image sequence layout. 
+ idx = 0 + for i in range(batch_size): + shift = 0 + for ref_img_len in l_effective_ref_img_len[i]: + ref_image_hidden_states[i, shift : shift + ref_img_len] = ( + batch_ref_image_hidden_states[idx, :ref_img_len] + ) + shift += ref_img_len + idx += 1 + + combined_img_hidden_states = hidden_states.new_zeros( + batch_size, max_combined_img_len, self.config.hidden_size + ) + for i, (ref_img_len, img_len) in enumerate( + zip(l_effective_ref_img_len, l_effective_img_len) + ): + combined_img_hidden_states[i, : sum(ref_img_len)] = ref_image_hidden_states[ + i, : sum(ref_img_len) + ] + combined_img_hidden_states[ + i, sum(ref_img_len) : sum(ref_img_len) + img_len + ] = hidden_states[i, :img_len] + + return combined_img_hidden_states + + def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): + """Flatten patch tokens and pad to batched sequences.""" + batch_size = len(hidden_states) + p = self.config.patch_size + device = hidden_states[0].device + + img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] + l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes] + + if ref_image_hidden_states is not None: + ref_img_sizes = [ + [(img.size(1), img.size(2)) for img in imgs] + if imgs is not None + else None + for imgs in ref_image_hidden_states + ] + l_effective_ref_img_len = [ + [ + (ref_img_size[0] // p) * (ref_img_size[1] // p) + for ref_img_size in _ref_img_sizes + ] + if _ref_img_sizes is not None + else [0] + for _ref_img_sizes in ref_img_sizes + ] + else: + ref_img_sizes = [None for _ in range(batch_size)] + l_effective_ref_img_len = [[0] for _ in range(batch_size)] + + max_ref_img_len = max( + [sum(ref_img_len) for ref_img_len in l_effective_ref_img_len] + ) + max_img_len = max(l_effective_img_len) + + # Reference-image patch embeddings. 
+ flat_ref_img_hidden_states = [] + for i in range(batch_size): + if ref_img_sizes[i] is not None: + imgs = [] + for ref_img in ref_image_hidden_states[i]: + C, H, W = ref_img.size() + ref_img = rearrange( + ref_img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p + ) + imgs.append(ref_img) + + img = torch.cat(imgs, dim=0) + flat_ref_img_hidden_states.append(img) + else: + flat_ref_img_hidden_states.append(None) + + # Noise-image patch embeddings. + flat_hidden_states = [] + for i in range(batch_size): + img = hidden_states[i] + C, H, W = img.size() + + img = rearrange(img, "c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=p, p2=p) + flat_hidden_states.append(img) + + padded_ref_img_hidden_states = torch.zeros( + batch_size, + max_ref_img_len, + flat_hidden_states[0].shape[-1], + device=device, + dtype=flat_hidden_states[0].dtype, + ) + padded_ref_img_mask = torch.zeros( + batch_size, max_ref_img_len, dtype=torch.bool, device=device + ) + for i in range(batch_size): + if ref_img_sizes[i] is not None: + padded_ref_img_hidden_states[i, : sum(l_effective_ref_img_len[i])] = ( + flat_ref_img_hidden_states[i] + ) + padded_ref_img_mask[i, : sum(l_effective_ref_img_len[i])] = True + + padded_hidden_states = torch.zeros( + batch_size, + max_img_len, + flat_hidden_states[0].shape[-1], + device=device, + dtype=flat_hidden_states[0].dtype, + ) + padded_img_mask = torch.zeros( + batch_size, max_img_len, dtype=torch.bool, device=device + ) + for i in range(batch_size): + padded_hidden_states[i, : l_effective_img_len[i]] = flat_hidden_states[i] + padded_img_mask[i, : l_effective_img_len[i]] = True + + return ( + padded_hidden_states, + padded_ref_img_hidden_states, + padded_img_mask, + padded_ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) + + def cal_preprocessed_instruction_feat_dim( + self, instruction_feature_configs: Dict[str, Any] + ): + num_instruction_feat_layers = max( + 
instruction_feature_configs.get("num_instruction_feat_layers", 1), 1 + ) + instruction_feat_dim = instruction_feature_configs.get( + "instruction_feat_dim", 4096 + ) + reduce_type = instruction_feature_configs.get("reduce_type", "concat") + if "cat" in reduce_type.lower(): + return num_instruction_feat_layers * instruction_feat_dim + elif "mean" in reduce_type.lower(): + return instruction_feat_dim + else: + raise ValueError(f"Invalid reduce_type: {reduce_type}") + + def preprocess_instruction_hidden_states( + self, raw_instruction_hidden_states, instruction_feature_configs: Dict[str, Any] + ): + num_instruction_feat_layers = max( + instruction_feature_configs.get("num_instruction_feat_layers", 1), 1 + ) + reduce_type = instruction_feature_configs.get("reduce_type", "concat") + + instruction_hidden_states = None + if isinstance(raw_instruction_hidden_states, torch.Tensor): + instruction_hidden_states = raw_instruction_hidden_states + elif isinstance(raw_instruction_hidden_states, (list, tuple)): + assert len(raw_instruction_hidden_states) == num_instruction_feat_layers + if "cat" in reduce_type.lower(): + instruction_hidden_states = torch.cat( + raw_instruction_hidden_states, dim=-1 + ) + elif "mean" in reduce_type.lower(): + instruction_hidden_states = torch.mean( + torch.stack(raw_instruction_hidden_states), dim=0 + ) + else: + raise ValueError(f"Invalid reduce_type: {reduce_type}") + else: + raise ValueError( + "Invalid type of raw_instruction_hidden_states, expected torch.Tensor " + f"or list, but got {type(raw_instruction_hidden_states)}" + ) + + assert ( + self.preprocessed_instruction_feat_dim + == instruction_hidden_states.shape[-1] + ) + + return instruction_hidden_states + + def forward( + self, + hidden_states: Union[torch.Tensor, List[torch.Tensor]], + timestep: torch.Tensor, + instruction_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + instruction_attention_mask: torch.Tensor, + ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = 
None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """context/refiner -> double-stream -> fusion -> single-stream -> projection.""" + instruction_hidden_states = self.preprocess_instruction_hidden_states( + instruction_hidden_states, self.instruction_feature_configs + ) + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(self, lora_scale) + elif attention_kwargs is not None and attention_kwargs.get("scale") is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT " + "backend is ineffective." + ) + + batch_size = len(hidden_states) + is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor) + + if is_hidden_states_tensor: + assert hidden_states.ndim == 4 + hidden_states = [_hidden_states for _hidden_states in hidden_states] + + device = hidden_states[0].device + + # Timestep and instruction embedding. + temb, instruction_hidden_states = self.time_caption_embed( + timestep, instruction_hidden_states, hidden_states[0].dtype + ) + + # Flatten and pad token sequences. + ( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) + + # Build rotary embeddings and sequence lengths. + ( + context_rotary_emb, + ref_img_rotary_emb, + noise_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + combined_img_rotary_emb, + combined_img_seq_lengths, + ) = self.rope_embedder( + freqs_cis, + instruction_attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device, + ) + + # Context refinement (non-modulated, so no temb). 
+ for layer in self.context_refiner: + instruction_hidden_states = self._ckpt( + layer, + instruction_hidden_states, + instruction_attention_mask, + context_rotary_emb, + ) + + # Image patch embedding and refinement. + combined_img_hidden_states = self.img_patch_embed_and_refine( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb, + ) + + instruct_hidden_states = instruction_hidden_states + img_hidden_states = combined_img_hidden_states + + # Joint mask for [instruct + image]. + max_seq_len = max(seq_lengths) + joint_attention_mask = hidden_states.new_zeros( + batch_size, max_seq_len, dtype=torch.bool + ) + for i, seq_len in enumerate(seq_lengths): + joint_attention_mask[i, :seq_len] = True + + # Double-stream stage. + if self.num_double_stream_layers > 0: + max_img_len = max(combined_img_seq_lengths) + img_attention_mask = hidden_states.new_zeros( + batch_size, max_img_len, dtype=torch.bool + ) + for i, img_seq_len in enumerate(combined_img_seq_lengths): + img_attention_mask[i, :img_seq_len] = True + + for layer in self.double_stream_layers: + img_hidden_states, instruct_hidden_states = self._ckpt( + layer, + img_hidden_states, + instruct_hidden_states, + img_attention_mask, + joint_attention_mask, + combined_img_rotary_emb, + rotary_emb, + temb, + encoder_seq_lengths, + seq_lengths, + ) + + # Fuse streams to joint sequence. + joint_hidden_states = hidden_states.new_zeros( + batch_size, max(seq_lengths), self.config.hidden_size + ) + for i, (encoder_seq_len, seq_len) in enumerate( + zip(encoder_seq_lengths, seq_lengths) + ): + joint_hidden_states[i, :encoder_seq_len] = instruct_hidden_states[ + i, :encoder_seq_len + ] + joint_hidden_states[i, encoder_seq_len:seq_len] = img_hidden_states[ + i, : seq_len - encoder_seq_len + ] + + hidden_states = joint_hidden_states + + # Single-stream stage. 
+ for layer in self.single_stream_layers: + hidden_states = self._ckpt( + layer, hidden_states, joint_attention_mask, rotary_emb, temb + ) + + # Output projection. + hidden_states = self.norm_out(hidden_states, temb) + + # Reshape back to image format. + p = self.config.patch_size + output = [] + for i, (img_size, img_len, seq_len) in enumerate( + zip(img_sizes, l_effective_img_len, seq_lengths) + ): + height, width = img_size + img_tokens = hidden_states[i][seq_len - img_len : seq_len] + img_output = rearrange( + img_tokens, + "(h w) (p1 p2 c) -> c (h p1) (w p2)", + h=height // p, + w=width // p, + p1=p, + p2=p, + ) + output.append(img_output) + + if is_hidden_states_tensor: + output = torch.stack(output, dim=0) + + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return output + return Transformer2DModelOutput(sample=output) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/chroma/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/chroma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c34866a6a97d6f7af0cb468dbad154fb5ebfb0b1 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/chroma/__init__.py @@ -0,0 +1,2 @@ +from .chroma_model import ChromaModel +from .chroma_radiance_model import ChromaRadianceModel \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/chroma/chroma_model.py b/ai-toolkit/extensions_built_in/diffusion_models/chroma/chroma_model.py new file mode 100644 index 0000000000000000000000000000000000000000..236d9508bb66a816ecf19ffb89982cae8a7c8ebc --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/chroma/chroma_model.py @@ -0,0 +1,463 @@ +import os +from typing import TYPE_CHECKING + +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from diffusers import 
AutoencoderKL +# from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype +from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer +from .pipeline import ChromaPipeline, prepare_latent_image_ids +from einops import rearrange, repeat +import random +import torch.nn.functional as F +from .src.model import Chroma, chroma_params +from safetensors.torch import load_file, save_file +from toolkit.metadata import get_meta_for_safetensors +import huggingface_hub + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True +} + +class FakeConfig: + # for diffusers compatability + def __init__(self): + self.attention_head_dim = 128 + self.guidance_embeds = True + self.in_channels = 64 + self.joint_attention_dim = 4096 + self.num_attention_heads = 24 + self.num_layers = 19 + self.num_single_layers = 38 + self.patch_size = 1 + +class FakeCLIP(torch.nn.Module): + def __init__(self): + super().__init__() + self.dtype = torch.bfloat16 + self.device = 'cuda' + self.text_model = None + self.tokenizer = None + self.model_max_length = 77 + + def forward(self, *args, **kwargs): + return torch.zeros(1, 1, 1).to(self.device) + + +class ChromaModel(BaseModel): + arch = "chroma" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, + model_config, + dtype, + 
custom_pipeline, + noise_scheduler, + **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['Chroma'] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + # return the bucket divisibility for the model + return 32 + + def load_model(self): + dtype = self.torch_dtype + + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + + if model_path == "lodestones/Chroma": + print("Looking for latest Chroma checkpoint") + # get the latest checkpoint + files_list = huggingface_hub.list_repo_files(model_path) + print(files_list) + latest_version = 28 # current latest version at time of writing + while True: + if f"chroma-unlocked-v{latest_version}.safetensors" not in files_list: + latest_version -= 1 + break + else: + latest_version += 1 + print(f"Using latest Chroma version: v{latest_version}") + + # make sure we have it + model_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=f"chroma-unlocked-v{latest_version}.safetensors", + ) + elif model_path.startswith("lodestones/Chroma/v"): + # get the version number + version = model_path.split("/")[-1].split("v")[-1] + print(f"Using Chroma version: v{version}") + # make sure we have it + model_path = huggingface_hub.hf_hub_download( + repo_id='lodestones/Chroma', + filename=f"chroma-unlocked-v{version}.safetensors", + ) + elif model_path.startswith("lodestones/Chroma1-"): + # will have a file in the repo that is Chroma1-whatever.safetensors + model_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=f"{model_path.split('/')[-1]}.safetensors", + ) + else: + # check if the model path is a local file + if os.path.exists(model_path): + print(f"Using local model: {model_path}") + else: + raise ValueError(f"Model path {model_path} does 
not exist") + + # extras_path = 'black-forest-labs/FLUX.1-schnell' + # schnell model is gated now, use flex instead + extras_path = 'ostris/Flex.1-alpha' + + self.print_and_status_update("Loading transformer") + + chroma_state_dict = load_file(model_path, 'cpu') + + # determine number of double and single blocks + double_blocks = 0 + single_blocks = 0 + for key in chroma_state_dict.keys(): + if "double_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > double_blocks: + double_blocks = block_num + elif "single_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > single_blocks: + single_blocks = block_num + print(f"Double Blocks: {double_blocks}") + print(f"Single Blocks: {single_blocks}") + + chroma_params.depth = double_blocks + chroma_params.depth_single_blocks = single_blocks + transformer = Chroma(chroma_params) + + # add dtype, not sure why it doesnt have it + transformer.dtype = dtype + # load the state dict into the model + transformer.load_state_dict(chroma_state_dict) + + transformer.to(self.quantize_device, dtype=dtype) + + transformer.config = FakeConfig() + transformer.config.num_layers = double_blocks + transformer.config.num_single_layers = single_blocks + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained( + extras_path, subfolder="tokenizer_2", torch_dtype=dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + extras_path, subfolder="text_encoder_2", torch_dtype=dtype + ) + text_encoder_2.to(self.device_torch, dtype=dtype) + 
flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder_2, weights=get_qtype( + self.model_config.qtype)) + freeze(text_encoder_2) + flush() + + # self.print_and_status_update("Loading CLIP") + text_encoder = FakeCLIP() + tokenizer = FakeCLIP() + text_encoder.to(self.device_torch, dtype=dtype) + + self.noise_scheduler = ChromaModel.get_train_scheduler() + + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + extras_path, + subfolder="vae", + torch_dtype=dtype + ) + vae = vae.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Making pipe") + + pipe: ChromaPipeline = ChromaPipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = ChromaModel.get_train_scheduler() + pipeline = ChromaPipeline( + scheduler=scheduler, + 
text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer_2=self.tokenizer[1], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer) + ) + + # pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: ChromaPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + + extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds + extra['negative_prompt_attn_mask'] = unconditional_embeds.attention_mask + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_attn_mask=conditional_embeds.attention_mask, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + with torch.no_grad(): + bs, c, h, w = latent_model_input.shape + latent_model_input_packed = rearrange( + latent_model_input, + "b c (h ph) (w pw) -> b (h w) (c ph pw)", + ph=2, + pw=2 + ) + + img_ids = prepare_latent_image_ids( + bs, + h, + w, + patch_size=2 + ).to(device=self.device_torch) + + # img_ids = torch.zeros(h // 2, w // 2, 3) + # img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + # img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + # img_ids = repeat(img_ids, "h w c -> b (h w) c", + # b=bs).to(self.device_torch) + + txt_ids = torch.zeros( + bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + guidance = torch.full([1], 0, device=self.device_torch, dtype=torch.float32) + guidance = 
guidance.expand(latent_model_input_packed.shape[0]) + + cast_dtype = self.unet.dtype + + noise_pred = self.unet( + img=latent_model_input_packed.to( + self.device_torch, cast_dtype + ), + img_ids=img_ids, + txt=text_embeddings.text_embeds.to( + self.device_torch, cast_dtype + ), + txt_ids=txt_ids, + txt_mask=text_embeddings.attention_mask.to( + self.device_torch, cast_dtype + ), + timesteps=timestep / 1000, + guidance=guidance + ) + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + noise_pred = rearrange( + noise_pred, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=latent_model_input.shape[2] // 2, + w=latent_model_input.shape[3] // 2, + ph=2, + pw=2, + c=self.vae.config.latent_channels + ) + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if isinstance(prompt, str): + prompts = [prompt] + else: + prompts = prompt + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + max_length = 512 + + device = self.text_encoder[1].device + dtype = self.text_encoder[1].dtype + + # T5 + text_inputs = self.tokenizer[1]( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_embeds = self.text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder[1].dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + prompt_attention_mask = text_inputs["attention_mask"] + + pe = PromptEmbeds( + prompt_embeds + ) + pe.attention_mask = prompt_attention_mask + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return self.model.final_layer.linear.weight.requires_grad + + def get_te_has_grad(self): + # return from a weight if it has grad + return 
self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + + def save_model(self, output_path, meta, save_dtype): + if not output_path.endswith(".safetensors"): + output_path = output_path + ".safetensors" + # only save the unet + transformer: Chroma = unwrap_model(self.model) + state_dict = transformer.state_dict() + save_dict = {} + for k, v in state_dict.items(): + if isinstance(v, QTensor): + v = v.dequantize() + save_dict[k] = v.clone().to('cpu', dtype=save_dtype) + + meta = get_meta_for_safetensors(meta, name='chroma') + save_file(save_dict, output_path, metadata=meta) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + return (noise - batch.latents).detach() + + def convert_lora_weights_before_save(self, state_dict): + # currently starte with transformer. but needs to start with diffusion_model. for comfyui + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + # saved as diffusion_model. but needs to be transformer. 
for ai-toolkit + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + + def get_base_model_version(self): + return "chroma" diff --git a/ai-toolkit/extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py b/ai-toolkit/extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py new file mode 100644 index 0000000000000000000000000000000000000000..333600e8f1f11b698937e46bbb7113681e2ffb36 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py @@ -0,0 +1,445 @@ +import os +from typing import TYPE_CHECKING + +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from diffusers import AutoencoderKL +# from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype +from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer +from .pipeline import ChromaPipeline, prepare_latent_image_ids +from einops import rearrange, repeat +import random +import torch.nn.functional as F +from .src.radiance import Chroma, chroma_params +from safetensors.torch import load_file, save_file +from toolkit.metadata import get_meta_for_safetensors +from toolkit.models.FakeVAE import FakeVAE +import huggingface_hub + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + 
"num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True +} + +class FakeConfig: + # for diffusers compatability + def __init__(self): + self.attention_head_dim = 128 + self.guidance_embeds = True + self.in_channels = 64 + self.joint_attention_dim = 4096 + self.num_attention_heads = 24 + self.num_layers = 19 + self.num_single_layers = 38 + self.patch_size = 1 + +class FakeCLIP(torch.nn.Module): + def __init__(self): + super().__init__() + self.dtype = torch.bfloat16 + self.device = 'cuda' + self.text_model = None + self.tokenizer = None + self.model_max_length = 77 + + def forward(self, *args, **kwargs): + return torch.zeros(1, 1, 1).to(self.device) + + +class ChromaRadianceModel(BaseModel): + arch = "chroma_radiance" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, + model_config, + dtype, + custom_pipeline, + noise_scheduler, + **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['Chroma'] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + # return the bucket divisibility for the model + return 32 + + def load_model(self): + dtype = self.torch_dtype + + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + + if model_path == "lodestones/Chroma": + print("Looking for latest Chroma checkpoint") + # get the latest checkpoint + files_list = huggingface_hub.list_repo_files(model_path) + print(files_list) + latest_version = 28 # current latest version at time of writing + while True: + if f"chroma-unlocked-v{latest_version}.safetensors" not in files_list: + latest_version -= 1 + break + else: + latest_version += 1 + print(f"Using latest Chroma version: 
v{latest_version}") + + # make sure we have it + model_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=f"chroma-unlocked-v{latest_version}.safetensors", + ) + elif model_path.startswith("lodestones/Chroma/v"): + # get the version number + version = model_path.split("/")[-1].split("v")[-1] + print(f"Using Chroma version: v{version}") + # make sure we have it + model_path = huggingface_hub.hf_hub_download( + repo_id='lodestones/Chroma', + filename=f"chroma-unlocked-v{version}.safetensors", + ) + elif model_path.startswith("lodestones/Chroma1-"): + # will have a file in the repo that is Chroma1-whatever.safetensors + model_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=f"{model_path.split('/')[-1]}.safetensors", + ) + + else: + # check if the model path is a local file + if os.path.exists(model_path): + print(f"Using local model: {model_path}") + else: + raise ValueError(f"Model path {model_path} does not exist") + + # extras_path = 'black-forest-labs/FLUX.1-schnell' + # schnell model is gated now, use flex instead + extras_path = 'ostris/Flex.1-alpha' + + self.print_and_status_update("Loading transformer") + + if model_path.endswith('.pth') or model_path.endswith('.pt'): + chroma_state_dict = torch.load(model_path, map_location='cpu', weights_only=True) + else: + chroma_state_dict = load_file(model_path, 'cpu') + + # determine number of double and single blocks + double_blocks = 0 + single_blocks = 0 + for key in chroma_state_dict.keys(): + if "double_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > double_blocks: + double_blocks = block_num + elif "single_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > single_blocks: + single_blocks = block_num + print(f"Double Blocks: {double_blocks}") + print(f"Single Blocks: {single_blocks}") + + chroma_params.depth = double_blocks + chroma_params.depth_single_blocks = single_blocks + transformer = Chroma(chroma_params) + + # add 
dtype, not sure why it doesnt have it + transformer.dtype = dtype + # load the state dict into the model + transformer.load_state_dict(chroma_state_dict) + + transformer.to(self.quantize_device, dtype=dtype) + + transformer.config = FakeConfig() + transformer.config.num_layers = double_blocks + transformer.config.num_single_layers = single_blocks + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained( + extras_path, subfolder="tokenizer_2", torch_dtype=dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + extras_path, subfolder="text_encoder_2", torch_dtype=dtype + ) + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder_2, weights=get_qtype( + self.model_config.qtype)) + freeze(text_encoder_2) + flush() + + # self.print_and_status_update("Loading CLIP") + text_encoder = FakeCLIP() + tokenizer = FakeCLIP() + text_encoder.to(self.device_torch, dtype=dtype) + + self.noise_scheduler = ChromaRadianceModel.get_train_scheduler() + + self.print_and_status_update("Loading VAE") + # vae = AutoencoderKL.from_pretrained( + # extras_path, + # subfolder="vae", + # torch_dtype=dtype + # ) + vae = FakeVAE() + vae = vae.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Making pipe") + + pipe: ChromaPipeline = ChromaPipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + 
vae=vae, + transformer=None, + is_radiance=True, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = ChromaRadianceModel.get_train_scheduler() + pipeline = ChromaPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer_2=self.tokenizer[1], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + is_radiance=True, + ) + + # pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: ChromaPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + + extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds + extra['negative_prompt_attn_mask'] = unconditional_embeds.attention_mask + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + 
prompt_attn_mask=conditional_embeds.attention_mask, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + with torch.no_grad(): + bs, c, h, w = latent_model_input.shape + + img_ids = prepare_latent_image_ids( + bs, h, w, patch_size=16 + ).to(self.device_torch) + + txt_ids = torch.zeros( + bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + guidance = torch.full([1], 0, device=self.device_torch, dtype=torch.float32) + guidance = guidance.expand(bs) + + cast_dtype = self.unet.dtype + + noise_pred = self.unet( + img=latent_model_input.to( + self.device_torch, cast_dtype + ), + img_ids=img_ids, + txt=text_embeddings.text_embeds.to( + self.device_torch, cast_dtype + ), + txt_ids=txt_ids, + txt_mask=text_embeddings.attention_mask.to( + self.device_torch, cast_dtype + ), + timesteps=timestep / 1000, + guidance=guidance + ) + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if isinstance(prompt, str): + prompts = [prompt] + else: + prompts = prompt + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + max_length = 512 + + device = self.text_encoder[1].device + dtype = self.text_encoder[1].dtype + + # T5 + text_inputs = self.tokenizer[1]( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_embeds = self.text_encoder[1](text_input_ids.to(device), 
output_hidden_states=False)[0] + + dtype = self.text_encoder[1].dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + prompt_attention_mask = text_inputs["attention_mask"] + + pe = PromptEmbeds( + prompt_embeds + ) + pe.attention_mask = prompt_attention_mask + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return False + def get_te_has_grad(self): + # return from a weight if it has grad + return False + + def save_model(self, output_path, meta, save_dtype): + if not output_path.endswith(".safetensors"): + output_path = output_path + ".safetensors" + # only save the unet + transformer: Chroma = unwrap_model(self.model) + state_dict = transformer.state_dict() + save_dict = {} + for k, v in state_dict.items(): + if isinstance(v, QTensor): + v = v.dequantize() + save_dict[k] = v.clone().to('cpu', dtype=save_dtype) + + meta = get_meta_for_safetensors(meta, name='chroma') + save_file(save_dict, output_path, metadata=meta) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + return (noise - batch.latents).detach() + + def convert_lora_weights_before_save(self, state_dict): + # currently starte with transformer. but needs to start with diffusion_model. for comfyui + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + # saved as diffusion_model. but needs to be transformer. 
for ai-toolkit + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + + def get_base_model_version(self): + return "chroma_radiance" diff --git a/ai-toolkit/extensions_built_in/diffusion_models/chroma/pipeline.py b/ai-toolkit/extensions_built_in/diffusion_models/chroma/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..5a76a71312ad0be31b3ffe86b7fe4689d9ac698a --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/chroma/pipeline.py @@ -0,0 +1,328 @@ +from typing import Union, List, Optional, Dict, Any, Callable + +import numpy as np +import torch +from diffusers import FluxPipeline +from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.utils import is_torch_xla_available +from diffusers.utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +def prepare_latent_image_ids(batch_size, height, width, patch_size=2, max_offset=0): + """ + Generates positional embeddings for a latent image. + + Args: + batch_size (int): The number of images in the batch. + height (int): The height of the image. + width (int): The width of the image. + patch_size (int, optional): The size of the patches. Defaults to 2. + max_offset (int, optional): The maximum random offset to apply. Defaults to 0. + + Returns: + torch.Tensor: A tensor containing the positional embeddings. 
+ """ + # the random pos embedding helps generalize to larger res without training at large res + # pos embedding for rope, 2d pos embedding, corner embedding and not center based + latent_image_ids = torch.zeros(height // patch_size, width // patch_size, 3) + + # Add positional encodings + latent_image_ids[..., 1] = ( + latent_image_ids[..., 1] + torch.arange(height // patch_size)[:, None] + ) + latent_image_ids[..., 2] = ( + latent_image_ids[..., 2] + torch.arange(width // patch_size)[None, :] + ) + + # Add random offset if specified + if max_offset > 0: + offset_y = torch.randint(0, max_offset + 1, (1,)).item() + offset_x = torch.randint(0, max_offset + 1, (1,)).item() + latent_image_ids[..., 1] += offset_y + latent_image_ids[..., 2] += offset_x + + + ( + latent_image_id_height, + latent_image_id_width, + latent_image_id_channels, + ) = latent_image_ids.shape + + # Reshape for batch + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, + latent_image_id_height * latent_image_id_width, + latent_image_id_channels, + ) + + return latent_image_ids + + +class ChromaPipeline(FluxPipeline): + def __init__( + self, + scheduler, + vae, + text_encoder, + tokenizer, + text_encoder_2, + tokenizer_2, + transformer, + image_encoder = None, + feature_extractor = None, + is_radiance: bool = False, + ): + super().__init__( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.is_radiance = is_radiance + self.vae_scale_factor = 8 if not is_radiance else 1 + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent 
height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = prepare_latent_image_ids( + batch_size, + height, + width, + patch_size=2 if not self.is_radiance else 16 + ).to(device=device, dtype=dtype) + # latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if not self.is_radiance: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + # latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = prepare_latent_image_ids( + batch_size, + height, + width, + patch_size=2 if not self.is_radiance else 16 + ).to(device=device, dtype=dtype) + + return latents, latent_image_ids + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + 
prompt_attn_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attn_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[ + int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if isinstance(device, str): + device = torch.device(device) + + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16) + if guidance_scale > 1.00001: + negative_text_ids = torch.zeros(batch_size, negative_prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16) + + # 4. Prepare latent variables + num_channels_latents = 64 // 4 + if self.is_radiance: + num_channels_latents = 3 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # extend img ids to match batch size + # latent_image_ids = latent_image_ids.unsqueeze(0) + # latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0) + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + guidance = torch.full([1], 0, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + + noise_pred_text = self.transformer( + img=latents, + img_ids=latent_image_ids, + txt=prompt_embeds, + txt_ids=text_ids, + txt_mask=prompt_attn_mask, # todo add this + timesteps=timestep / 1000, + guidance=guidance + ) + + if guidance_scale > 1.00001: + noise_pred_uncond = self.transformer( + img=latents, + img_ids=latent_image_ids, + txt=negative_prompt_embeds, + txt_ids=negative_text_ids, + txt_mask=negative_prompt_attn_mask, # todo add this + timesteps=timestep / 1000, + guidance=guidance + ) + + noise_pred = noise_pred_uncond + self.guidance_scale * \ + (noise_pred_text - noise_pred_uncond) + + else: + noise_pred = noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if not self.is_radiance: + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + \ + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess( + image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/chroma/src/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/chroma/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f345402952e67596c3cf80b0b8580228ee8e09a7 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/chroma/src/__init__.py @@ -0,0 +1 @@ +# This is taken and slightly modified from https://github.com/lodestone-rock/flow/tree/master/src/models/chroma \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/chroma/src/layers.py b/ai-toolkit/extensions_built_in/diffusion_models/chroma/src/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..6a668323ccc2e937187da07c1a574d0daf735b67 --- 
/dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/chroma/src/layers.py @@ -0,0 +1,720 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn +import torch.nn.functional as F + +from .math import attention, rope +from functools import lru_cache + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, use_compiled: bool = False): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + self.use_compiled = use_compiled + + def _forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + def forward(self, x: Tensor): + return F.rms_norm(x, self.scale.shape, weight=self.scale, eps=1e-6) + # if self.use_compiled: + # return torch.compile(self._forward)(x) + # else: + # return self._forward(x) + + +def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks): + """ + Distributes slices of the tensor into the block_dict as ModulationOut objects. + + Args: + tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim]. + """ + batch_size, vectors, dim = tensor.shape + + block_dict = {} + + # HARD CODED VALUES! 
lookup table for the generated vectors + # TODO: move this into chroma config! + # Add 38 single mod blocks + for i in range(depth_single_blocks): + key = f"single_blocks.{i}.modulation.lin" + block_dict[key] = None + + # Add 19 image double blocks + for i in range(depth_double_blocks): + key = f"double_blocks.{i}.img_mod.lin" + block_dict[key] = None + + # Add 19 text double blocks + for i in range(depth_double_blocks): + key = f"double_blocks.{i}.txt_mod.lin" + block_dict[key] = None + + # Add the final layer + block_dict["final_layer.adaLN_modulation.1"] = None + # 6.2b version + # block_dict["lite_double_blocks.4.img_mod.lin"] = None + # block_dict["lite_double_blocks.4.txt_mod.lin"] = None + + idx = 0 # Index to keep track of the vector slices + + for key in block_dict.keys(): + if "single_blocks" in key: + # Single block: 1 ModulationOut + block_dict[key] = ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + idx += 3 # Advance by 3 vectors + + elif "img_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "txt_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "final_layer" in key: + # Final layer: 1 ModulationOut + block_dict[key] = [ + tensor[:, idx : idx + 1, :], + tensor[:, idx + 1 : idx + 2, :], + ] + idx += 
class NerfEmbedder(nn.Module):
    """
    Embeds per-pixel features together with a 2D, DCT-like positional encoding.

    The input has shape (B, P^2, C), where P is the side length of a square
    patch. Each of the P^2 positions is concatenated with ``max_freqs**2``
    cosine-basis features and the result is projected to ``hidden_size_input``.
    """

    # Maximum number of positional grids kept per instance.
    _POS_CACHE_SIZE = 4

    def __init__(self, in_channels, hidden_size_input, max_freqs):
        """
        Args:
            in_channels (int): Number of channels in the input tensor.
            hidden_size_input (int): Dimension of the output embedding.
            max_freqs (int): Number of frequencies used along each of x and y;
                the positional encoding has ``max_freqs**2`` features.
        """
        super().__init__()
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input

        # Projects [input features | positional features] to the output width.
        self.embedder = nn.Sequential(
            nn.Linear(in_channels + max_freqs**2, hidden_size_input)
        )

        # Per-instance cache of positional grids. A plain dict is used instead of
        # functools.lru_cache on the method, because that cache lives on the class,
        # is keyed on `self`, and therefore keeps every instance (and its cached
        # tensors) alive for the lifetime of the process.
        self._pos_cache = {}

    def fetch_pos(self, patch_size, device, dtype):
        """
        Return the (cached) positional grid for a ``patch_size`` x ``patch_size`` patch.

        Args:
            patch_size (int): Side length of the square patch.
            device: Torch device (or device string) to create the grid on.
            dtype: Torch dtype of the grid.

        Returns:
            Tensor of shape (1, patch_size^2, max_freqs^2). The tensor is shared
            between calls and must not be modified in place.
        """
        cache = self.__dict__.setdefault("_pos_cache", {})
        key = (patch_size, torch.device(device), dtype)
        cached = cache.get(key)
        if cached is not None:
            # Re-insert so the entry counts as most recently used.
            cache[key] = cache.pop(key)
            return cached

        # Normalised coordinates in [0, 1], identical for both axes.
        coords = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(coords, coords, indexing="ij")

        # (patch_size^2, 1, 1) so positions broadcast against the frequency axes.
        pos_x = pos_x.reshape(-1, 1, 1)
        pos_y = pos_y.reshape(-1, 1, 1)

        # Integer frequencies 0 .. max_freqs - 1.
        freqs = torch.linspace(
            0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device
        )
        freqs_x = freqs[None, :, None]  # (1, max_freqs, 1)
        freqs_y = freqs[None, None, :]  # (1, 1, max_freqs)

        # Weight 1 / (1 + fx * fy): shrinks as the frequency product grows.
        coeffs = (1 + freqs_x * freqs_y) ** -1

        # Separable cosine basis along each axis, combined by broadcasting into
        # a (patch_size^2, max_freqs, max_freqs) grid and flattened per position.
        dct_x = torch.cos(pos_x * freqs_x * torch.pi)
        dct_y = torch.cos(pos_y * freqs_y * torch.pi)
        dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)

        # Evict the least recently used entries before inserting the new grid.
        while len(cache) >= self._POS_CACHE_SIZE:
            cache.pop(next(iter(cache)))
        cache[key] = dct
        return dct

    def forward(self, inputs):
        """
        Args:
            inputs (Tensor): Input of shape (B, P^2, C).

        Returns:
            Tensor of shape (B, P^2, hidden_size_input) in the dtype of ``inputs``.

        Raises:
            ValueError: If the token dimension is not a perfect square.
        """
        B, P2, C = inputs.shape

        # Exact integer square root; reject token counts that are not a square
        # patch instead of failing later with a shape mismatch in torch.cat.
        patch_size = math.isqrt(int(P2))
        if patch_size * patch_size != P2:
            raise ValueError(
                f"NerfEmbedder expects a square number of positions, got {P2}"
            )

        original_dtype = inputs.dtype
        # Run everything in this module in fp32, regardless of CUDA autocast.
        with torch.autocast("cuda", enabled=False):
            inputs = inputs.float()
            dct = self.fetch_pos(patch_size, inputs.device, torch.float32)

            # A broadcast view is enough: torch.cat materialises the result anyway.
            dct = dct.expand(B, -1, -1)
            inputs = torch.cat([inputs, dct], dim=-1)

            # NOTE: Module.float() casts the projection's parameters to fp32 in
            # place, so the cast persists after this call.
            inputs = self.embedder.float()(inputs)

        return inputs.to(original_dtype)
class NerfFinalLayer(nn.Module):
    """Output head: RMS-normalise the features, then apply a zero-initialised linear projection."""

    def __init__(self, hidden_size, out_channels, use_compiled):
        super().__init__()
        self.norm = RMSNorm(hidden_size, use_compiled=use_compiled)
        self.linear = nn.Linear(hidden_size, out_channels)
        # Start from an all-zero projection (weight and bias).
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        """Normalise ``x`` over its last dimension and project it to ``out_channels``."""
        return self.linear(self.norm(x))
class SelfAttention(nn.Module):
    """
    Multi-head self-attention with RMS-normalised queries/keys and rotary
    position embeddings.

    Args:
        dim: Model width; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        use_compiled: Forwarded to the q/k normalisation layers.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        use_compiled: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        # One fused projection produces q, k and v.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim, use_compiled=use_compiled)
        self.proj = nn.Linear(dim, dim)
        self.use_compiled = use_compiled

    def forward(self, x: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
        """
        Args:
            x: Input of shape (B, L, dim).
            pe: Rotary position embedding handed to ``attention``.
            mask: Optional attention mask handed to ``attention``; ``None``
                means unmasked attention.

        Returns:
            Tensor of shape (B, L, dim).
        """
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        # `attention` declares `mask` as a parameter; the previous call omitted
        # it and raised TypeError. Pass it explicitly (None by default).
        x = attention(q, k, v, pe=pe, mask=mask)
        return self.proj(x)
modulation_shift_scale_fn(self, x, scale, shift): + if self.use_compiled: + return torch.compile(_modulation_shift_scale_fn)(x, scale, shift) + else: + return _modulation_shift_scale_fn(x, scale, shift) + + def modulation_gate_fn(self, x, gate, gate_params): + if self.use_compiled: + return torch.compile(_modulation_gate_fn)(x, gate, gate_params) + else: + return _modulation_gate_fn(x, gate, gate_params) + + def forward( + self, + img: Tensor, + txt: Tensor, + pe: Tensor, + distill_vec: list[ModulationOut], + mask: Tensor, + ) -> tuple[Tensor, Tensor]: + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec + + # prepare image for attention + img_modulated = self.img_norm1(img) + # replaced with compiled fn + # img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_modulated = self.modulation_shift_scale_fn( + img_modulated, img_mod1.scale, img_mod1.shift + ) + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + # replaced with compiled fn + # txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_modulated = self.modulation_shift_scale_fn( + txt_modulated, txt_mod1.scale, txt_mod1.shift + ) + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe, mask=mask) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + # replaced with compiled fn + # img = img + img_mod1.gate * self.img_attn.proj(img_attn) + # img = img + img_mod2.gate * 
class SingleStreamBlock(nn.Module):
    """
    Transformer block with parallel attention and MLP branches
    (https://arxiv.org/abs/2302.05442).

    One fused input projection yields q/k/v plus the MLP hidden activations,
    and one fused output projection consumes the attention result concatenated
    with the activated MLP branch. Shift/scale/gate modulation is supplied by
    the caller through ``distill_vec``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
        use_compiled: bool = False,
    ):
        super().__init__()
        head_dim = hidden_size // num_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        # Plain configuration attributes.
        self.hidden_dim = hidden_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim**-0.5
        self.use_compiled = use_compiled

        # Fused input projection: q, k, v and the MLP hidden layer.
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)
        # Fused output projection: attention output and MLP output.
        self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim, use_compiled=use_compiled)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")

    @property
    def device(self):
        # Device of the first parameter (assumes all parameters share one device).
        return next(self.parameters()).device

    def modulation_shift_scale_fn(self, x, scale, shift):
        """Return ``(1 + scale) * x + shift``, via ``torch.compile`` when enabled."""
        fn = (
            torch.compile(_modulation_shift_scale_fn)
            if self.use_compiled
            else _modulation_shift_scale_fn
        )
        return fn(x, scale, shift)

    def modulation_gate_fn(self, x, gate, gate_params):
        """Return ``x + gate * gate_params``, via ``torch.compile`` when enabled."""
        fn = (
            torch.compile(_modulation_gate_fn)
            if self.use_compiled
            else _modulation_gate_fn
        )
        return fn(x, gate, gate_params)

    def forward(
        self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor
    ) -> Tensor:
        """
        Run the block on ``x`` of shape (B, L, hidden_size).

        ``distill_vec`` provides the ``shift``, ``scale`` and ``gate`` tensors;
        ``pe`` and ``mask`` are forwarded to ``attention``.
        """
        modulated = self.modulation_shift_scale_fn(
            self.pre_norm(x), distill_vec.scale, distill_vec.shift
        )

        # Split the fused projection into the attention and MLP branches.
        fused = self.linear1(modulated)
        qkv, mlp_hidden = fused.split(
            [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        query, key, value = rearrange(
            qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        query, key = self.norm(query, key, value)
        attn_out = attention(query, key, value, pe=pe, mask=mask)

        # Activate the MLP branch, rejoin both branches and project back.
        merged = torch.cat((attn_out, self.mlp_act(mlp_hidden)), 2)
        projected = self.linear2(merged)

        # Gated residual connection.
        return self.modulation_gate_fn(x, distill_vec.gate, projected)
ModuleNotFoundError): + _HAS_FLASH = False + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + # mask should have shape [B, H, L, D] + if _HAS_FLASH and mask is None and q.is_cuda: + x = flash_attn_func( + rearrange(q, "B H L D -> B L H D").contiguous(), + rearrange(k, "B H L D -> B L H D").contiguous(), + rearrange(v, "B H L D -> B L H D").contiguous(), + dropout_p=0.0, + softmax_scale=None, + causal=False, + ) + x = rearrange(x, "B L H D -> B H L D") + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) + + x = rearrange(x, "B H L D -> B L (H D)") + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack( + [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1 + ) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/chroma/src/model.py b/ai-toolkit/extensions_built_in/diffusion_models/chroma/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ebdf69d939bcbfcfe391c1d04aea1b17c3c5a7a8 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/chroma/src/model.py @@ -0,0 +1,282 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +import 
def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
    """
    Unmask a few padding positions directly after each sequence.

    The number of attended tokens in a row is taken to be the row sum of
    ``mask``, and the positions immediately after that count are unmasked, so
    the mask is assumed to mark a contiguous prefix of real tokens.

    Args:
        mask: Attention mask of shape (B, L); 1 for tokens to attend to, 0 for
            masked tokens.
        max_seq_length: Positions at or beyond this index are never unmasked.
        num_extra_padding: Number of padding tokens to unmask per row;
            values <= 0 leave the mask unchanged.

    Returns:
        A new mask with the same shape and dtype; the input is not modified.
    """
    # Vectorised over the batch: no Python loop and no per-row `.item()` sync.
    seq_length = mask.sum(dim=-1).long().unsqueeze(-1)
    extra = max(int(num_extra_padding), 0)

    # Unmask [seq_length, min(seq_length + extra, max_seq_length)); the range is
    # empty for rows that already reach max_seq_length.
    unmask_end = torch.clamp(seq_length + extra, max=max_seq_length)
    positions = torch.arange(mask.shape[-1], device=mask.device)
    to_unmask = (positions >= seq_length) & (positions < unmask_end)

    return mask.masked_fill(to_unmask, 1)
+ # currently the mapping is hardcoded in distribute_modulations function + self.distilled_guidance_layer = Approximator( + params.approximator_in_dim, + self.hidden_size, + params.approximator_hidden_size, + params.approximator_depth, + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + use_compiled=params._use_compiled, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + use_compiled=params._use_compiled, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer( + self.hidden_size, + 1, + self.out_channels, + use_compiled=params._use_compiled, + ) + + # TODO: move this hardcoded value to config + # single layer has 3 modulation vectors + # double layer has 6 modulation vectors for each expert + # final layer has 2 modulation vectors + self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2 + self.depth_single_blocks = params.depth_single_blocks + self.depth_double_blocks = params.depth + # self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0) + self.register_buffer( + "mod_index", + torch.tensor(list(range(self.mod_index_length)), device="cpu"), + persistent=False, + ) + self.approximator_in_dim = params.approximator_in_dim + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def enable_gradient_checkpointing(self, enable: bool = True): + self.gradient_checkpointing = enable + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + txt_mask: Tensor, + timesteps: Tensor, + guidance: Tensor, + attn_padding: int = 1, + ) -> Tensor: + if img.ndim 
!= 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + txt = self.txt_in(txt) + + # TODO: + # need to fix grad accumulation issue here for now it's in no grad mode + # besides, i don't want to wash out the PFP that's trained on this model weights anyway + # the fan out operation here is deleting the backward graph + # alternatively doing forward pass for every block manually is doable but slow + # custom backward probably be better + with torch.no_grad(): + distill_timestep = timestep_embedding(timesteps, 16) + # TODO: need to add toggle to omit this from schnell but that's not a priority + distil_guidance = timestep_embedding(guidance, 16) + # get all modulation index + modulation_index = timestep_embedding(self.mod_index, 32) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) + # and we need to broadcast timestep and guidance along too + timestep_guidance = ( + torch.cat([distill_timestep, distil_guidance], dim=1) + .unsqueeze(1) + .repeat(1, self.mod_index_length, 1) + ) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) + mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True)) + mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + # compute mask + # assume max seq length from the batched input + + max_len = txt.shape[1] + + # mask + with torch.no_grad(): + txt_mask_w_padding = modify_mask_to_attend_padding( + txt_mask, max_len, attn_padding + ) + txt_img_mask = torch.cat( + [ + txt_mask_w_padding, + torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device), + ], + dim=1, + ) + txt_img_mask = txt_img_mask.float().T @ 
txt_img_mask.float() + txt_img_mask = ( + txt_img_mask[None, None, ...] + .repeat(txt.shape[0], self.num_heads, 1, 1) + .int() + .bool() + ) + # txt_mask_w_padding[txt_mask_w_padding==False] = True + + for i, block in enumerate(self.double_blocks): + # the guidance replaced by FFN output + img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] + txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] + double_mod = [img_mod, txt_mod] + + if torch.is_grad_enabled() and self.gradient_checkpointing: + img.requires_grad_(True) + img, txt = ckpt.checkpoint( + block, img, txt, pe, double_mod, txt_img_mask + ) + else: + img, txt = block( + img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask + ) + + img = torch.cat((txt, img), 1) + for i, block in enumerate(self.single_blocks): + single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + if torch.is_grad_enabled() and self.gradient_checkpointing: + img.requires_grad_(True) + img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask) + else: + img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) + img = img[:, txt.shape[1] :, ...] 
def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
    """
    Unmask a few padding positions directly after each sequence.

    The number of attended tokens in a row is taken to be the row sum of
    ``mask``, and the positions immediately after that count are unmasked, so
    the mask is assumed to mark a contiguous prefix of real tokens.

    Args:
        mask: Attention mask of shape (B, L); 1 for tokens to attend to, 0 for
            masked tokens.
        max_seq_length: Positions at or beyond this index are never unmasked.
        num_extra_padding: Number of padding tokens to unmask per row;
            values <= 0 leave the mask unchanged.

    Returns:
        A new mask with the same shape and dtype; the input is not modified.
    """
    # Vectorised over the batch: no Python loop and no per-row `.item()` sync.
    seq_length = mask.sum(dim=-1).long().unsqueeze(-1)
    extra = max(int(num_extra_padding), 0)

    # Unmask [seq_length, min(seq_length + extra, max_seq_length)); the range is
    # empty for rows that already reach max_seq_length.
    unmask_end = torch.clamp(seq_length + extra, max=max_seq_length)
    positions = torch.arange(mask.shape[-1], device=mask.device)
    to_unmask = (positions >= seq_length) & (positions < unmask_end)

    return mask.masked_fill(to_unmask, 1)
nn.init.zeros_(self.img_in_patch.bias) + # TODO: need proper mapping for this approximator output! + # currently the mapping is hardcoded in distribute_modulations function + self.distilled_guidance_layer = Approximator( + params.approximator_in_dim, + self.hidden_size, + params.approximator_hidden_size, + params.approximator_depth, + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + use_compiled=params._use_compiled, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + use_compiled=params._use_compiled, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + # self.final_layer = LastLayer( + # self.hidden_size, + # 1, + # self.out_channels, + # use_compiled=params._use_compiled, + # ) + + # pixel channel concat with DCT + self.nerf_image_embedder = NerfEmbedder( + in_channels=params.in_channels, + hidden_size_input=params.nerf_hidden_size, + max_freqs=params.nerf_max_freqs + ) + + self.nerf_blocks = nn.ModuleList([ + NerfGLUBlock( + hidden_size_s=params.hidden_size, + hidden_size_x=params.nerf_hidden_size, + mlp_ratio=params.nerf_mlp_ratio, + use_compiled=params._use_compiled + ) for _ in range(params.nerf_depth) + ]) + # self.nerf_final_layer = NerfFinalLayer( + # params.nerf_hidden_size, + # out_channels=params.in_channels, + # use_compiled=params._use_compiled + # ) + self.nerf_final_layer_conv = NerfFinalLayerConv( + params.nerf_hidden_size, + out_channels=params.in_channels, + use_compiled=params._use_compiled + ) + # TODO: move this hardcoded value to config + # single layer has 3 modulation vectors + # double layer has 6 modulation vectors for each expert + # final layer has 2 modulation vectors + self.mod_index_length = 3 * params.depth_single_blocks 
def forward(
    self,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    txt_mask: Tensor,
    timesteps: Tensor,
    guidance: Tensor,
    attn_padding: int = 1,
) -> Tensor:
    """
    Run the pixel-space Chroma transformer followed by the per-patch NeRF head.

    Args:
        img: Image batch; must be 4-D and is unpacked as [B, C, H, W].
        img_ids: Position ids for the image tokens, concatenated after ``txt_ids``
            along dim 1 and fed to ``self.pe_embedder``.
        txt: Text features; must be 3-D. Projected by ``self.txt_in``.
        txt_ids: Position ids for the text tokens.
        txt_mask: Text attention mask. It is concatenated along dim 1 with a
            ``[B, NumPatches]`` tensor of ones, so presumably ``[B, txt_len]``
            (TODO confirm against caller).
        timesteps: Passed to ``timestep_embedding`` for the modulation approximator.
        guidance: Passed to ``timestep_embedding`` for the modulation approximator.
        attn_padding: Forwarded to ``modify_mask_to_attend_padding`` as the number
            of padding tokens to unmask.

    Returns:
        The output of ``self.nerf_final_layer_conv.conv`` applied to the patches
        folded back to spatial size ``(H, W)``.

    Raises:
        ValueError: If ``img`` is not 4-D or ``txt`` is not 3-D.
    """
    if img.ndim != 4:
        raise ValueError("Input img tensor must be in [B, C, H, W] format.")
    if txt.ndim != 3:
        raise ValueError("Input txt tensors must have 3 dimensions.")
    B, C, H, W = img.shape

    # gemini gogogo idk how to unfold and pack the patch properly :P
    # Store the raw pixel values of each patch for the NeRF head later.
    # unfold creates patches: [B, C * P * P, NumPatches]
    nerf_pixels = nn.functional.unfold(img, kernel_size=self.params.patch_size, stride=self.params.patch_size)
    nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]

    # patchify ops
    img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
    num_patches = img.shape[2] * img.shape[3]
    # flatten into a sequence for the transformer.
    img = img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]

    txt = self.txt_in(txt)

    # TODO:
    # need to fix grad accumulation issue here for now it's in no grad mode
    # besides, i don't want to wash out the PFP that's trained on this model weights anyway
    # the fan out operation here is deleting the backward graph
    # alternatively doing forward pass for every block manually is doable but slow
    # custom backward probably be better
    with torch.no_grad():
        distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim//4)
        # TODO: need to add toggle to omit this from schnell but that's not a priority
        distil_guidance = timestep_embedding(guidance, self.approximator_in_dim//4)
        # get all modulation index
        modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim//2)
        # we need to broadcast the modulation index here so each batch has all of the index
        modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
        # and we need to broadcast timestep and guidance along too
        timestep_guidance = (
            torch.cat([distill_timestep, distil_guidance], dim=1)
            .unsqueeze(1)
            .repeat(1, self.mod_index_length, 1)
        )
        # then and only then we could concatenate it together
        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
        mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True))
    # Split the approximator output into one entry per block, keyed by names such
    # as "double_blocks.{i}.img_mod.lin" and "single_blocks.{i}.modulation.lin".
    mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)

    # Text ids come first, matching the (txt, img) token order used below.
    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = self.pe_embedder(ids)

    # compute mask
    # assume max seq length from the batched input

    max_len = txt.shape[1]

    # mask
    with torch.no_grad():
        txt_mask_w_padding = modify_mask_to_attend_padding(
            txt_mask, max_len, attn_padding
        )
        # Image tokens are always attended to.
        txt_img_mask = torch.cat(
            [
                txt_mask_w_padding,
                torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device),
            ],
            dim=1,
        )
        # NOTE(review): `.T @` contracts over the batch dimension, so the result is
        # a single [S, S] matrix shared by every sample (a pair is kept if any
        # sample has both positions unmasked). Confirm this is intended for B > 1.
        txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float()
        # Expand to [B, num_heads, S, S] and convert non-zero counts to True.
        txt_img_mask = (
            txt_img_mask[None, None, ...]
            .repeat(txt.shape[0], self.num_heads, 1, 1)
            .int()
            .bool()
        )
        # txt_mask_w_padding[txt_mask_w_padding==False] = True

    for i, block in enumerate(self.double_blocks):
        # the guidance replaced by FFN output
        img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
        txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
        double_mod = [img_mod, txt_mod]

        # just in case in different GPU for simple pipeline parallel
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Mark the input as requiring grad before checkpointing the block.
            img.requires_grad_(True)
            img, txt = ckpt.checkpoint(
                block, img, txt, pe, double_mod, txt_img_mask
            )
        else:
            img, txt = block(
                img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask
            )

    # Single-stream blocks run on the joint (txt, img) sequence.
    img = torch.cat((txt, img), 1)
    for i, block in enumerate(self.single_blocks):
        single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            img.requires_grad_(True)
            img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
        else:
            img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask)
    # Drop the text tokens again; only image tokens feed the NeRF head.
    img = img[:, txt.shape[1] :, ...]

    # final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
    # img = self.final_layer(
    #     img, distill_vec=final_mod
    # )  # (N, T, patch_size ** 2 * out_channels)

    # aliasing
    nerf_hidden = img
    # reshape for per-patch processing
    nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size)
    nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2)

    # get DCT-encoded pixel embeddings [pixel-dct]
    img_dct = self.nerf_image_embedder(nerf_pixels)

    # pass through the dynamic MLP blocks (the NeRF)
    # NOTE(review): checkpointing here is keyed on self.training, not on
    # self.gradient_checkpointing as in the transformer blocks above — confirm.
    for i, block in enumerate(self.nerf_blocks):
        if self.training:
            img_dct = ckpt.checkpoint(block, img_dct, nerf_hidden)
        else:
            img_dct = block(img_dct, nerf_hidden)

    # final projection to get the output pixel values
    # img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C]
    img_dct = self.nerf_final_layer_conv.norm(img_dct)

    # gemini gogogo idk how to fold this properly :P
    # Reassemble the patches into the final image.
    # NOTE(review): the "C" in the shape comments below is presumably the NeRF
    # feature width rather than the image channel count, since the channel
    # projection (.conv) only happens after fold — confirm.
    img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
    # Reshape to combine with batch dimension for fold
    img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
    img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
    img_dct = nn.functional.fold(
        img_dct,
        output_size=(H, W),
        kernel_size=self.params.patch_size,
        stride=self.params.patch_size
    ) # [B, Hidden, H, W]
    img_dct = self.nerf_final_layer_conv.conv(img_dct)

    return img_dct
class ErnieImageModel(BaseModel):
    """
    AI Toolkit model wrapper for ERNIE-Image.

    Loads the patched ``ErnieImageTransformer2DModel``, an ``AutoModel`` text
    encoder with its tokenizer, and an ``AutoencoderKLFlux2`` VAE, and assembles
    them into an ``ErnieImagePipeline`` used for latent (un)patchifying, text
    padding and sampling.
    """

    arch = "ernie_image"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        """Forward all arguments to ``BaseModel`` and set the architecture flags."""
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["ErnieImageTransformer2DModel"]

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        """Build a flow-match Euler scheduler from the module-level ``scheduler_config``."""
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        """Return the pixel multiple (32) that image width/height are rounded to."""
        return 16 * 2  # 16 for the VAE, 2 for patch size

    def load_model(self):
        """
        Load transformer, text encoder, tokenizer and VAE and build the pipeline.

        On return ``self.vae``, ``self.text_encoder`` (one-element list),
        ``self.tokenizer`` (one-element list), ``self.model``, ``self.pipeline``
        and ``self.noise_scheduler`` are set.
        """
        dtype = self.torch_dtype
        self.print_and_status_update("Loading ErnieImage model")
        model_path = self.model_config.name_or_path
        base_model_path = self.model_config.extras_name_or_path

        self.print_and_status_update("Loading transformer")

        # Default: treat model_path as a hub id and load its "transformer" subfolder.
        transformer_path = model_path
        transformer_subfolder = "transformer"
        if os.path.exists(transformer_path):
            # Local directory: point straight at <model_path>/transformer instead.
            transformer_subfolder = None
            transformer_path = os.path.join(transformer_path, "transformer")
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, "text_encoder")
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path

        transformer = ErnieImageTransformer2DModel.from_pretrained(
            transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
        )

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing Transformer")
            quantize_model(self, transformer)
            flush()

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_transformer_percent > 0
        ):
            # x_embedder is excluded from offloading; the transformer's `device`
            # property reads its parameters to report the model device.
            MemoryManager.attach(
                transformer,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_transformer_percent,
                ignore_modules=[
                    transformer.x_embedder,
                ],
            )

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer to CPU")
            transformer.to("cpu")

        flush()

        self.print_and_status_update("Text Encoder")
        # NOTE(review): torch_dtype is passed to the tokenizer as well; a tokenizer
        # presumably has no use for it — confirm it is harmless.
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_path, subfolder="tokenizer", torch_dtype=dtype
        )
        text_encoder = AutoModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=dtype
        )

        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_text_encoder_percent > 0
        ):
            MemoryManager.attach(
                text_encoder,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
            )

        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Text Encoder")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
            freeze(text_encoder)
            flush()

        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKLFlux2.from_pretrained(
            base_model_path, subfolder="vae", torch_dtype=dtype
        ).to(self.device_torch, dtype=dtype)

        self.noise_scheduler = ErnieImageModel.get_train_scheduler()

        self.print_and_status_update("Making pipe")

        kwargs = {}

        # Build the pipeline without the heavy modules, then attach them below.
        pipe: ErnieImagePipeline = ErnieImagePipeline(
            scheduler=self.noise_scheduler,
            text_encoder=None,
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
            **kwargs,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = [pipe.text_encoder]
        tokenizer = [pipe.tokenizer]

        # leave it on cpu for now
        # NOTE(review): reads self.low_vram, while the earlier check reads
        # self.model_config.low_vram — presumably BaseModel mirrors it; confirm.
        if not self.low_vram:
            pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        """Build a fresh sampling pipeline from the unwrapped modules and move it to the device."""
        scheduler = ErnieImageModel.get_train_scheduler()

        # NOTE(review): uses self.transformer although load_model assigns self.model;
        # presumably BaseModel exposes `transformer` as an alias — confirm.
        pipeline: ErnieImagePipeline = ErnieImagePipeline(
            scheduler=scheduler,
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer),
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
        """
        Encode images to patchified latents normalised with the VAE's ``bn`` statistics.

        ``image_list`` may be a list of tensors (stacked along dim 0) or an
        already-batched tensor. ``device``/``dtype`` default to the VAE's.
        """
        # Bring the VAE back to the training device if it was parked on CPU.
        if self.vae.device == torch.device("cpu"):
            self.vae.to(self.device_torch)
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype
        self.vae.eval()
        self.vae.requires_grad_(False)

        image = image_list
        if isinstance(image, list):
            image = torch.stack(image, dim=0)

        image = image.to(device, dtype=dtype)

        latents = self.vae.encode(image).latent_dist.sample()

        latents = self.pipeline._patchify_latents(latents)

        # Per-channel normalisation: (x - running_mean) / sqrt(running_var + 1e-5).
        bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(
            device=latents.device, dtype=latents.dtype
        )
        bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(
            device=latents.device, dtype=latents.dtype
        )
        latents = (latents - bn_mean) / bn_std

        return latents

    def decode_latents(self, latents: torch.Tensor, device=None, dtype=None):
        """Invert ``encode_images``: de-normalise, unpatchify and decode latents to images."""
        if self.vae.device == torch.device("cpu"):
            self.vae.to(self.device_torch)
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype

        latents = latents.to(device, dtype=dtype)
        # NOTE(review): unlike encode_images, the statistics are moved to `device`
        # only, without casting to the latents' dtype — confirm this is intended.
        bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device)
        bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device)
        latents = latents * bn_std + bn_mean

        # Unpatchify
        latents = self.pipeline._unpatchify_latents(latents)

        # Decode
        images = self.vae.decode(latents, return_dict=False)[0]
        return images

    def generate_single_image(
        self,
        pipeline: ErnieImagePipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: AdvancedPromptEmbeds,
        unconditional_embeds: AdvancedPromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """
        Sample one image with ``pipeline`` and return the first entry of ``.images``.

        ``gen_config.width`` and ``gen_config.height`` are rounded down in place to
        a multiple of ``get_bucket_divisibility()``.
        """
        if self.model.device == torch.device("cpu"):
            self.model.to(self.device_torch)

        sc = self.get_bucket_divisibility()
        gen_config.width = int(gen_config.width // sc * sc)
        gen_config.height = int(gen_config.height // sc * sc)

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra,
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: AdvancedPromptEmbeds,
        **kwargs,
    ):
        """Pad the text embeddings into a batch and return the transformer's first output."""
        if self.model.device == torch.device("cpu"):
            self.model.to(self.device_torch)

        # NOTE(review): the padded text takes the VAE's dtype rather than the
        # transformer's — presumably they are the same; confirm.
        text_bth, text_lens = self.pipeline._pad_text(
            text_hiddens=text_embeddings.text_embeds,
            device=self.device_torch,
            dtype=self.vae.dtype,
            text_in_dim=self.pipeline.transformer.config.text_in_dim,
        )

        pred = self.transformer(
            hidden_states=latent_model_input,
            timestep=timestep,
            text_bth=text_bth,
            text_lens=text_lens,
            return_dict=False,
        )[0]

        return pred

    def get_prompt_embeds(self, prompt: str) -> AdvancedPromptEmbeds:
        """
        Encode one prompt or a list of prompts into per-prompt hidden states.

        Although annotated as ``str``, a non-string ``prompt`` is iterated as a
        sequence of prompts. Each prompt is tokenised and encoded on its own
        (no padding), so the returned ``text_embeds`` is a list of ``[T, H]``
        tensors, one per prompt.
        """
        if self.pipeline.text_encoder.device == torch.device("cpu"):
            self.pipeline.text_encoder.to(self.device_torch)

        if isinstance(prompt, str):
            prompt = [prompt]

        text_hiddens = []

        for p in prompt:
            ids = self.pipeline.tokenizer(
                p,
                add_special_tokens=True,
                truncation=True,
                padding=False,
            )["input_ids"]

            # Never feed an empty sequence: fall back to BOS, or token id 0.
            if len(ids) == 0:
                if self.pipeline.tokenizer.bos_token_id is not None:
                    ids = [self.pipeline.tokenizer.bos_token_id]
                else:
                    ids = [0]

            input_ids = torch.tensor([ids], device=self.device_torch)
            outputs = self.pipeline.text_encoder(
                input_ids=input_ids,
                output_hidden_states=True,
            )
            # Use second to last hidden state (matches training)
            hidden = outputs.hidden_states[-2][0]  # [T, H]

            text_hiddens.append(hidden)

        pe = AdvancedPromptEmbeds(text_embeds=text_hiddens)
        return pe

    def get_model_has_grad(self):
        """Always reports False."""
        return False

    def get_te_has_grad(self):
        """Always reports False."""
        return False

    def save_model(self, output_path, meta, save_dtype):
        """
        Save the unwrapped transformer to ``<output_path>/transformer`` (safetensors)
        and dump ``meta`` to ``<output_path>/aitk_meta.yaml``.

        ``save_dtype`` is accepted but not used by this implementation.
        """
        transformer: ErnieImageTransformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        """Return the detached target ``noise - batch.latents`` from the ``noise`` and ``batch`` kwargs."""
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        return (noise - batch.latents).detach()

    def get_base_model_version(self):
        """Return the architecture tag (``arch``)."""
        return self.arch

    def get_transformer_block_names(self) -> Optional[List[str]]:
        """Return the attribute name(s) holding the transformer blocks."""
        return ["layers"]

    def convert_lora_weights_before_save(self, state_dict):
        """Return a new dict with every ``"transformer."`` in each key replaced by ``"diffusion_model."``."""
        new_sd = {}
        for key, value in state_dict.items():
            # str.replace rewrites every occurrence in the key, not only a prefix.
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        """Inverse of ``convert_lora_weights_before_save``: ``"diffusion_model."`` -> ``"transformer."``."""
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """
    Compute rotary-embedding angles for one positional axis.

    Builds ``dim // 2`` inverse frequencies ``theta ** -(2k / dim)`` and returns
    the outer product with ``pos`` as a float32 tensor of shape
    ``pos.shape + (dim // 2,)``. ``dim`` must be even.
    """
    assert dim % 2 == 0
    # Exponents 0, 2/dim, 4/dim, ... on the same device as the positions.
    exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
    inv_freq = 1.0 / (theta**exponents)
    # One angle per (position, frequency) pair.
    angles = torch.einsum("...n,d->...nd", pos, inv_freq)
    return angles.float()
torch.Tensor: + x = self.proj(x) + batch_size, dim, height, width = x.shape + return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous() + + +class ErnieImageSingleStreamAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ErnieImageSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + freqs_cis: torch.Tensor | None = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE: same rotate_half logic as Megatron _apply_rotary_pos_emb_bshd (rotary_interleaved=False) + # x_in: [B, S, heads, head_dim], freqs_cis: [B, S, 1, head_dim] with angles [θ0,θ0,θ1,θ1,...] 
class ErnieImageAttention(torch.nn.Module, AttentionModuleMixin):
    """
    Self-attention module holding the q/k/v/out projections and optional
    per-head QK normalisation; the actual attention math is delegated to the
    configured processor.

    Args:
        query_dim: Input feature size.
        heads: Number of attention heads (ignored when ``out_dim`` is given, in
            which case it is ``out_dim // dim_head``).
        dim_head: Size of each head.
        dropout: Stored on the module; not applied by this class itself.
        bias: Whether the q/k/v projections use a bias.
        qk_norm: ``"rms_norm"``, ``"layer_norm"`` or ``None`` to disable QK
            normalisation (``norm_q``/``norm_k`` are then ``None``).
        added_proj_bias: Stored on the module; not used by this class itself.
        out_bias: Whether the output projection uses a bias.
        eps: Epsilon for the QK norm layers.
        out_dim: Optional override for both the inner and the output width.
        elementwise_affine: Whether the QK norm layers have learnable weights.
        processor: Attention processor; defaults to ``_default_processor_cls()``.

    Raises:
        ValueError: If ``qk_norm`` is not one of the supported values.
    """

    _default_processor_cls = ErnieImageSingleStreamAttnProcessor

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: str | None = "rms_norm",
        added_proj_bias: bool | None = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int | None = None,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.added_proj_bias = added_proj_bias

        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        # QK Norm
        if qk_norm is None:
            # No QK normalisation. The processor checks `attn.norm_q is not None`
            # before applying it, and the shared AdaLN block passes None when
            # qk_layernorm is disabled, so this must not raise.
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        elif qk_norm == "rms_norm":
            self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        else:
            # Only list the options this class actually implements.
            raise ValueError(
                f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
            )

        # Kept as a ModuleList so the projection is addressed as `to_out[0]`.
        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run the configured processor on ``hidden_states``.

        ``encoder_hidden_states`` is accepted for interface compatibility but is
        not forwarded to the processor. Extra keyword arguments the processor's
        ``__call__`` does not declare are dropped with a warning.
        """
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
shift_mlp.float()).to(x.dtype) + return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype) + + +class ErnieImageAdaLNContinuous(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps) + self.linear = nn.Linear(hidden_size, hidden_size * 2) + + def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + scale, shift = self.linear(conditioning).chunk(2, dim=-1) + x = self.norm(x) + # Broadcast conditioning to sequence dimension + x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return x + + +class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + hidden_size: int = 3072, + num_attention_heads: int = 24, + num_layers: int = 24, + ffn_hidden_size: int = 8192, + in_channels: int = 128, + out_channels: int = 128, + patch_size: int = 1, + text_in_dim: int = 2560, + rope_theta: int = 256, + rope_axes_dim: Tuple[int, int, int] = (32, 48, 48), + eps: float = 1e-6, + qk_layernorm: bool = True, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.num_layers = num_layers + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.text_in_dim = text_in_dim + + self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size) + self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None + self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0) + self.time_embedding = TimestepEmbedding(hidden_size, hidden_size) + self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * 
hidden_size)) + nn.init.zeros_(self.adaLN_modulation[-1].weight) + nn.init.zeros_(self.adaLN_modulation[-1].bias) + self.layers = nn.ModuleList( + [ + ErnieImageSharedAdaLNBlock( + hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm + ) + for _ in range(num_layers) + ] + ) + self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps) + self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) + nn.init.zeros_(self.final_linear.weight) + nn.init.zeros_(self.final_linear.bias) + self.gradient_checkpointing = False + self.onload_device = None + + @property + def device(self): + # use self.x_embedder since we ignore it in memory management + return next(self.x_embedder.parameters()).device + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + # encoder_hidden_states: List[torch.Tensor], + text_bth: torch.Tensor, + text_lens: torch.Tensor, + return_dict: bool = True, + ): + device = self.device + dtype = self.dtype + B, C, H, W = hidden_states.shape + p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size + N_img = Hp * Wp + + img_bsh = self.x_embedder(hidden_states).contiguous() # (B, N_img, H) + # text_bth, text_lens = self._pad_text(encoder_hidden_states, device, dtype) + if self.text_proj is not None and text_bth.numel() > 0: + text_bth = self.text_proj(text_bth) + Tmax = text_bth.shape[1] + + x = torch.cat([img_bsh, text_bth], dim=1) # (B, S, H) + + # Position IDs + text_ids = ( + torch.cat( + [ + torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1), + torch.zeros((B, Tmax, 2), device=device), + ], + dim=-1, + ) + if Tmax > 0 + else torch.zeros((B, 0, 3), device=device) + ) + grid_yx = torch.stack( + torch.meshgrid( + torch.arange(Hp, device=device, dtype=torch.float32), + torch.arange(Wp, device=device, dtype=torch.float32), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + image_ids = torch.cat( +
[text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)], + dim=-1, + ) + rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) + + # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention + valid_text = ( + torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) + if Tmax > 0 + else torch.zeros((B, 0), device=device, dtype=torch.bool) + ) + attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[ + :, None, None, : + ] + + # AdaLN + sample = self.time_proj(timestep.to(dtype)) + sample = sample.to(dtype) + c = self.time_embedding(sample) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ + t.unsqueeze(1) for t in self.adaLN_modulation(c).chunk(6, dim=-1) + ] # each (B, 1, H), broadcasts over sequence + for layer in self.layers: + temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func( + layer, + x, + rotary_pos_emb, + temb, + attention_mask, + ) + else: + x = layer(x, rotary_pos_emb, temb, attention_mask) + x = self.final_norm(x, c).type_as(x) + patches = self.final_linear(x)[:, :N_img].contiguous() # (B, N_img, p*p*C) + output = ( + patches.view(B, Hp, Wp, p, p, self.out_channels) + .permute(0, 5, 1, 3, 2, 4) + .contiguous() + .view(B, self.out_channels, H, W) + ) + + return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/example_model/README.md b/ai-toolkit/extensions_built_in/diffusion_models/example_model/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4e35726078c1bce7dc5fcf4632f21245fff6f3cf --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/example_model/README.md @@ -0,0 +1,241 @@ +# Example Model — a template for adding 
a new architecture to ai-toolkit + +This folder is a complete, heavily commented template for wiring a brand-new +diffusion model into ai-toolkit. It assumes the worst (and most common) case: +**diffusers does not have your model**, so you vendor the network and a minimal +sampling pipeline yourself. + +It is intentionally **not registered** — it never appears as a trainable arch. +It exists purely as a guide for people (and agents) adding image, editing, +video, or i2v models. + +## File map + +``` +example_model/ +├── README.md <- you are here +├── __init__.py <- exports ExampleModel (registration notes inside) +├── example_model.py <- the BaseModel subclass: every override documented +│ with exact inputs/outputs +└── src/ <- everything diffusers does NOT provide + ├── model.py <- a minimal DiT with the gradient-checkpointing pattern + └── pipeline.py <- a minimal embeds-only flow-matching sampler +``` + +## How a model gets registered + +1. `toolkit/util/get_model.py:get_all_models()` scans every package directly + under `extensions/` and `extensions_built_in/` for a module-level + `AI_TOOLKIT_MODELS` list. +2. For models in this folder, that list lives in + `extensions_built_in/diffusion_models/__init__.py` — import your class + there and append it to `AI_TOOLKIT_MODELS`. + (Alternatively, give your model its own folder under `extensions/` with its + own `AI_TOOLKIT_MODELS` list — see `extensions/z_image_pixel/`.) +3. The class attribute `arch` (e.g. `"example"`) is matched against + `model.arch` in the training config YAML to pick your class. +4. To expose it in the web UI, add an entry to + `ui/src/app/jobs/new/options.ts` (search for an existing arch like + `ideogram4` to copy the shape).
+ +Minimal config YAML to train it: + +```yaml +model: + arch: "example" + name_or_path: "/path/to/weights" # folder with transformer/, text_encoder/, + # tokenizer/, vae/ + quantize: true # optional: qfloat8 the transformer + quantize_te: true # optional: qfloat8 the text encoder +train: + gradient_checkpointing: true +``` + +## Lifecycle — who calls what, in order + +1. **Load** — `load_model()` builds the transformer, text encoder(s), + tokenizer(s), VAE and scheduler and stores them on `self`. Everything else + reads `self.model` / `self.vae` / `self.text_encoder`. +2. **Caching (optional)** — before training, the trainer may call + `encode_images()` per dataset image (latent cache) and + `get_prompt_embeds()` per caption (text-embed cache, saved via + `AdvancedPromptEmbeds.save`, one file per caption). +3. **Train step** (every step, see `extensions_built_in/sd_trainer/SDTrainer.py`): + 1. clean latents come from the cache or `encode_images()` + 2. noise + timestep are sampled; `add_noise()` (BaseModel) mixes them + 3. `condition_noisy_latents(noisy_latents, batch)` — your hook to inject + control/reference conditioning + 4. `get_noise_prediction(latent_model_input, timestep, text_embeddings)` — + the forward pass, under autograd + 5. loss = MSE(prediction, `get_loss_target(noise=..., batch=...)`) +4. **Sampling previews** — `generate_images()` (BaseModel) encodes each sample + prompt with `get_prompt_embeds()`, then calls your + `get_generation_pipeline()` once and `generate_single_image(...)` per + prompt. Your pipeline only ever receives **embeds, never text**. +5. **Saving** — full fine-tunes go through `save_model()`. LoRA files are + written by the network code, with your + `convert_lora_weights_before_save/load()` mapping keys to the public + convention (usually the `diffusion_model.` prefix). 
+ +## Conventions to keep straight + +- **Pixels** are `(B, 3, H, W)` in `[-1, 1]` (control tensors arrive in + `[0, 1]` — multiply by 2 and subtract 1 before encoding). +- **Latents** are `(B, C, h, w)`; video latents are `(B, C, frames, h, w)`. +- **Timesteps** cross the BaseModel API on a `0..1000` scale where 1000 is + pure noise. Convert to your model's native convention inside + `get_noise_prediction` — and watch for models whose native time runs the + other way (t=1 = clean); flip and/or negate there (ideogram4 does both). +- **Flow-matching target** in this codebase is `noise - clean` + (`get_loss_target`), i.e. the velocity pointing from data to noise. +- `self.model` / `self.transformer` / `self.unet` are aliases for the same + thing on BaseModel. +- **`use_old_lokr_format = False`** — set this class attribute on every NEW + model. `BaseModel` defaults it to `True` purely for backwards-compatibility + with LoKr checkpoints trained before the format change; all new architectures + should use the new LoKr format. (Plain LoRA training is unaffected — this only + matters for `network.type: "lokr"`.) + +## AdvancedPromptEmbeds + +`toolkit/advanced_prompt_embeds.py`. The flexible container for text +conditioning, preferred for all new models over the older `PromptEmbeds`: + +- Every key holds a **list of tensors, one per batch item** + (`AdvancedPromptEmbeds(text_embeds=[t0, t1, ...])`). Store each item at its + natural length and pad to the batch max only at the model call + (`src/pipeline.py:pad_prompt_embeds`) — caches stay small and any prompts + can share a batch. +- **Keep each per-item tensor 2D `(L, D)`.** This is a hard requirement, not a + convention: `BaseModel.predict_noise` infers the text batch size from the + embed list, and it only counts the list as one-per-item when each tensor is + 2D (`len(text_embeds[0].shape) == 2`). 
A 3D per-item tensor is read as an + already-batched `(B, L, D)` and its *first axis* is taken as the batch size — + so a single 3D prompt of length `L` looks like a batch of `L`, and training + dies with *"Batch size of latents must be the same or half the batch size of + text embeddings."* If your conditioning has an extra axis (e.g. a stack of N + encoder layers, giving `(L, N, D)`), **flatten it into the feature axis** + (`(L, N*D)`) in `get_prompt_embeds` and **restore it** (`reshape(B, Lt, N, D)`) + in `get_noise_prediction` / the pipeline, right before the model call. +- Add as many keys as your model needs (`pooled_embeds`, image features, …). +- Keys that must not be dtype-cast (token ids, masks) go in + `embeds.frozen_dtype_keys`. +- CFG concat (`concat_prompt_embeds`), batch expansion, `.to()`, `.save()` / + `.load()` for the disk cache are all handled for you. + +If you ever change what `get_prompt_embeds` produces, bump the +`text_embedding_space_version` property so stale on-disk caches invalidate. + +## Gradient checkpointing + +With `train.gradient_checkpointing: true`, `BaseSDTrainProcess` calls +`model.enable_gradient_checkpointing()` if it exists, else sets +`model.gradient_checkpointing = True`. Your network re-runs each block under +`torch.utils.checkpoint.checkpoint(..., use_reentrant=False)` when the flag is +set **and** `torch.is_grad_enabled()` is true — never gate on `self.training`. +See `src/model.py` for the full pattern and rationale. + +## Quantization + +With `quantize: true`, `quantize_model` swaps every `nn.Linear` for an +`optimum.quanto` quantized one. Their matmul kernel **only accepts 2D or 3D +activations** (`assert activations.ndim in (2, 3)`) — a `Linear` you feed a 4D +tensor works fine in bf16 but throws once quantized. If your network applies a +`Linear` over a 4D tensor (e.g. projecting a `(B, L, D, N)` layer axis), +reshape to 3D for the call and back afterwards. 
+ +Also watch out for **slow bf16 kernels on vendored components**: `Conv3d` has no +fast cuDNN bf16 path (it falls back to a slow one). If a frozen sub-model carries +a `Conv3d` you don't actually run — e.g. a vision tower's patch embed on a VL +text encoder — drop it (`text_encoder.model.visual = None`) to skip loading it; +if you must run one, consider running that component in fp16/fp32. + +## Attention backends (don't force flash-attn) + +Reference repos very often hard-code an attention kernel — `flash_attn`, +xformers, sage — and import it at module top level. **Do not carry that +requirement over.** ai-toolkit has to import and load your model on machines +where that package isn't installed (CPU boxes, headless CI, plain installs), so +a top-level `from flash_attn import ...` turns "load the model" into an +`ImportError`. + +The rule: + +- **Default to torch's built-in `F.scaled_dot_product_attention`** (the + "native" backend). It needs no extra dependency, runs on CPU and CUDA, and + already dispatches to a fused/flash kernel on supported hardware. `src/model.py` + does exactly this. +- **Make any other kernel OPTIONAL**, selected at runtime — never required at + import. The clean pattern: + 1. Guard the import so a missing package is a flag, not a crash: + ```python + try: + from flash_attn import flash_attn_varlen_func + _FLASH_ATTN_AVAILABLE = True + except ImportError: + flash_attn_varlen_func = None + _FLASH_ATTN_AVAILABLE = False + ``` + 2. Give each attention module an `attention_backend` flag (default + `"native"`) and **branch inside its forward** — `"flash"` runs the flash + kernel, anything else runs SDPA. + 3. Expose a `set_attention_backend("native"|"flash")` on the parent model + that validates the name, raises a clear error if `"flash"` is requested + while `_FLASH_ATTN_AVAILABLE` is `False`, and propagates the flag to every + attention module. + 4. Wire it to a config knob so it stays opt-in, e.g. 
+ `model_kwargs.attention_backend: "flash"` read in `load_model`. + +Branch on a per-module **flag**, don't swap the processor/module instance: +attention modules that own trained q/k/v weights (joint/dual-stream blocks) +would lose those weights if you replaced them with a different instance. + +Worked implementations to copy: `../ideogram4/src/transformer.py` +(`set_attention_backend`, native+flash in one `Attention.forward`) and +`../boogu_image/src/attention_processor.py` (guarded import, per-processor +`attention_backend` flag, flash varlen branch alongside SDPA). + +## Adapting this template + +### Editing / instruct model (image in, image out) +- In `condition_noisy_latents`, encode `batch.control_tensor` + (`(B, 3, H, W)` in `[0, 1]`) with the VAE and attach it to the noisy + latents — extra channels (`torch.cat(..., dim=1)`) or extra sequence tokens. + Slice the prediction back down in `get_noise_prediction` before returning. + Reference: `../flux_kontext/flux_kontext.py`. +- If the text encoder must *see* the control image (VL encoders), set + `self.encode_control_in_text_embeddings = True`; `get_prompt_embeds` then + receives `control_images`. Reference: `../qwen_image/qwen_image_edit.py`. +- Multiple reference images: `self.has_multiple_control_images = True` + (`batch.control_tensor_list`). Reference: + `../qwen_image/qwen_image_edit_plus.py`. +- In `generate_single_image`, load `gen_config.ctrl_img` (a file path) and run + the same conditioning for previews. + +### Video model (t2v) +- Batches arrive as `(B, frames, 3, H, W)`; latents as + `(B, C, frames_latent, h, w)`. Override `encode_images`/`decode_latents` + for your video VAE (temporal compression means + `frames_latent = (frames - 1) // 4 + 1` for most VAEs). +- `gen_config.num_frames` drives previews; return a **list of PIL frames** + from `generate_single_image` and the harness saves a video. +- Reference: `../wan22/wan22_5b_model.py` and `../ltx2/`. 
+ +### Image-to-video (i2v) +- Same as video, plus first-frame conditioning: in `get_noise_prediction` + take frame 0 from `batch.tensor` (declare `batch` in your signature to + receive it), encode it, and merge it into the latent input. For previews do + the same with `gen_config.ctrl_img`. +- Reference: `../wan22/wan22_14b_i2v_model.py` and + `toolkit/models/wan21/wan_utils.py:add_first_frame_conditioning`. + +### Other useful hooks (all on `toolkit/models/base_model.py:BaseModel`) +| Override | When you need it | +|---|---| +| `get_model_to_train()` | LoRA should attach to something other than `self.model` | +| `text_embedding_space_version` / `latent_space_version` | invalidate users' caches after a breaking change | +| `te_padding_side` | LLM text encoders that need left padding | +| `is_multistage`, `multistage_boundaries` | multi-expert models split by timestep range (`../wan22/wan22_14b_model.py`) | +| `load_training_adapter()` pattern | assistant LoRAs (de-distillation adapters), see `../z_image/z_image.py` | +| `get_latent_noise_from_latents()` | custom noise (default: `randn_like`) | +| `encode_audio()` | audio-conditioned models (`../ltx2/`) | diff --git a/ai-toolkit/extensions_built_in/diffusion_models/example_model/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/example_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd9402a551e5037d21d4b30cc8a97af70322a50 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/example_model/__init__.py @@ -0,0 +1,12 @@ +# This is a documentation-only TEMPLATE model. Start with README.md in this +# folder for the full guide to adding a new model architecture to ai-toolkit. +# +# It is intentionally NOT registered: the parent package +# (extensions_built_in/diffusion_models/__init__.py) does not import it, so it +# never shows up as a trainable arch. To register a real model, import its +# class there and append it to the AI_TOOLKIT_MODELS list. 
(Models can also +# live in their own folder under extensions/, which defines its own +# AI_TOOLKIT_MODELS list -- see extensions/z_image_pixel for a tiny example.) +from .example_model import ExampleModel + +__all__ = ["ExampleModel"] diff --git a/ai-toolkit/extensions_built_in/diffusion_models/example_model/example_model.py b/ai-toolkit/extensions_built_in/diffusion_models/example_model/example_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a813f395dd4e72179c344a0664c2b73741223008 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/example_model/example_model.py @@ -0,0 +1,522 @@ +"""ExampleModel -- a fully documented template for adding a new model to ai-toolkit. + +Read README.md in this folder first for the big picture (lifecycle, data flow, +registration, and how to adapt this template into an edit / video / i2v model). + +Every override below documents: + - WHEN ai-toolkit calls it + - WHAT comes in (shapes, dtypes, scales) + - WHAT must come out + +The model itself is a made-up flow-matching DiT whose architecture lives in +./src/model.py and whose preview sampler lives in ./src/pipeline.py, simulating +the common case where diffusers does not ship your model and you vendor both. 
+""" + +import os +from typing import List, Optional + +import torch +import yaml +from safetensors.torch import load_file, save_file + +from diffusers import AutoencoderKL +from transformers import AutoTokenizer, AutoModel +from optimum.quanto import freeze + +from toolkit.accelerator import unwrap_model +from toolkit.advanced_prompt_embeds import AdvancedPromptEmbeds +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.models.base_model import BaseModel +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.util.quantize import quantize, get_qtype, quantize_model + +from .src.model import ExampleTransformer2DModel +from .src.pipeline import ExamplePipeline, pad_prompt_embeds + + +# Config for the training/sampling noise scheduler. ai-toolkit's flow-matching +# models all use CustomFlowMatchEulerDiscreteScheduler; ``shift`` warps the +# timestep distribution toward the high-noise end (bigger = more high-noise +# steps, typical for high-resolution models). +scheduler_config = { + "num_train_timesteps": 1000, + "use_dynamic_shifting": False, + "shift": 3.0, +} + + +class ExampleModel(BaseModel): + # ``arch`` is the unique id that ties everything together: + # - ``model.arch: "example"`` in the training config YAML selects this class + # (resolved by toolkit/util/get_model.py:get_model_class) + # - it is the default cache key for text-embedding / latent caches + arch = "example" + + # ALL NEW MODELS should set this to False. ``BaseModel`` defaults it to True + # only for backwards-compatibility with already-released LoKr checkpoints; the + # newer LoKr weight format is the correct one for any new architecture. + use_old_lokr_format = False + + def __init__( + self, + device, # "cuda:0" etc. 
+ model_config: ModelConfig, # the parsed ``model:`` section of the YAML + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + # --- flags the rest of the toolkit reads --- + # flow matching (velocity prediction) vs ddpm-style epsilon prediction + self.is_flow_matching = True + # transformer (DiT) vs unet: affects LoRA naming ("transformer." prefix) + self.is_transformer = True + # Class names of modules whose Linear layers get LoRA'd. Matched against + # type(module).__name__, so this must equal the class name in src/model.py. + self.target_lora_modules = ["ExampleTransformer2DModel"] + + # --- values used by our own overrides below --- + self.patch_size = 2 # transformer patch size (latent px per token) + self.vae_scale_factor = 8 # pixels per latent px (8x downsampling VAE) + # hard cap on prompt token length (truncation only -- embeds are stored + # per-sample at natural length, see get_prompt_embeds) + self.max_text_length = 512 + + # Other flags you may need (all default False, set in BaseModel.__init__): + # self.encode_control_in_text_embeddings = True + # -> get_prompt_embeds receives control_images (vision-language TEs + # that look at the control image, e.g. qwen_image_edit) + # self.has_multiple_control_images = True + # -> control images arrive as a list (qwen_image_edit_plus) + # self.use_raw_control_images = True + # -> control images are not resized to match the target image + # self.is_multistage = True + # -> model has multiple experts trained on timestep ranges (wan22 14b) + + @staticmethod + def get_train_scheduler(): + """Build the noise scheduler used for BOTH training and sampling. + + Called when loading the model, and again by the pipeline for every + preview run (a fresh instance, because scheduler state is mutable). 
+ """ + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + """Pixel multiple that dataset resolution buckets must snap to. + + The data loader crops every image so width/height are divisible by + this. Latents are 1/8 the pixel size (VAE) and the transformer eats + 2x2 latent patches, so pixels must be divisible by 8 * 2 = 16. + """ + return self.vae_scale_factor * self.patch_size + + # ------------------------------------------------------------------ + # Loading + # ------------------------------------------------------------------ + def load_model(self): + """Load every component and store them on ``self``. + + Called once at startup. ``self.model_config`` is the ``model:`` section + of the training YAML; the fields used here: + - name_or_path: local folder (or HF repo) with the weights + - quantize / qtype: quantize the transformer (e.g. "qfloat8") + - quantize_te / qtype_te: quantize the text encoder + - low_vram: keep big components on CPU; your other overrides then + move them to GPU on demand (see the device checks below) + + MUST set, before returning: + self.model the trainable denoiser (transformer/unet) + self.vae the (frozen) VAE + self.text_encoder one module or a list of modules (frozen unless + training the TE) + self.tokenizer one tokenizer or a list, parallel to text_encoder + self.noise_scheduler from get_train_scheduler() + self.pipeline anything generate_single_image can use + """ + dtype = self.torch_dtype + self.print_and_status_update("Loading Example model") + # Expected layout (diffusers-style folder): + # /transformer/model.safetensors + # /text_encoder/ + /tokenizer/ (transformers format) + # /vae/ (diffusers AutoencoderKL) + model_path = self.model_config.name_or_path + + # --- transformer (the custom model from src/) --- + self.print_and_status_update("Loading transformer") + # Instantiate on the meta device (no RAM used), then materialize the + # real tensors straight from the 
checkpoint with assign=True. This + # avoids allocating the model twice. If your model has non-persistent + # buffers, rebuild them after this (see ideogram4.py for an example). + with torch.device("meta"): + transformer = ExampleTransformer2DModel() + state_dict = load_file( + os.path.join(model_path, "transformer", "model.safetensors") + ) + state_dict = {k: v.to(dtype) for k, v in state_dict.items()} + transformer.load_state_dict(state_dict, assign=True) + del state_dict + flush() # gc + empty cuda cache; call it after dropping anything big + + if self.model_config.quantize: + # quantize_model handles qtype selection, exclusions and device + # juggling, and leaves the model on CPU + self.print_and_status_update("Quantizing transformer") + quantize_model(self, transformer) + flush() + + if self.model_config.low_vram: + # leave it on CPU; get_noise_prediction moves it over when needed + transformer.to("cpu") + else: + transformer.to(self.device_torch, dtype=dtype) + flush() + # For partial layer offloading support see MemoryManager.attach usage + # in ../ideogram4/ideogram4.py or ../z_image/z_image.py. 
+ + # --- text encoder + tokenizer (stock transformers model) --- + self.print_and_status_update("Loading text encoder") + tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") + text_encoder = AutoModel.from_pretrained( + model_path, subfolder="text_encoder", torch_dtype=dtype + ) + text_encoder.to(self.te_device_torch) + # the TE is frozen here; only set requires_grad if you train it + text_encoder.eval() + text_encoder.requires_grad_(False) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing text encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + + # --- VAE --- + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae") + vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + vae.eval() + vae.requires_grad_(False) + flush() + + # --- scheduler + store everything --- + self.noise_scheduler = ExampleModel.get_train_scheduler() + self.vae = vae + self.text_encoder = text_encoder # could be a list for multi-TE models + self.tokenizer = tokenizer # parallel list if multiple TEs + self.model = transformer # aliased as self.transformer / self.unet + self.pipeline = ExamplePipeline(self) + self.print_and_status_update("Model Loaded") + + # ------------------------------------------------------------------ + # Sampling (training previews) + # ------------------------------------------------------------------ + def get_generation_pipeline(self): + """Return a fresh pipeline for a round of preview sampling. + + Called once per sampling round by BaseModel.generate_images. Our + pipeline holds no state, so a new lightweight wrapper is enough. 
+ """ + return ExamplePipeline(self) + + def generate_single_image( + self, + pipeline: ExamplePipeline, + gen_config: GenerateImageConfig, # one sample_prompts entry: width, + # height, seed, num_inference_steps, + # guidance_scale, ctrl_img, num_frames... + conditional_embeds: AdvancedPromptEmbeds, # already-encoded prompt + unconditional_embeds: AdvancedPromptEmbeds, # already-encoded negative prompt + generator: torch.Generator, # seeded with gen_config.seed + extra: dict, # adapter kwargs (controlnet etc.) + ): + """Render ONE preview image. + + The harness (BaseModel.generate_images) has already encoded the + prompts with get_prompt_embeds -- the pipeline never sees text. + + Returns a PIL.Image (or for video models a list of PIL frames). + """ + # low_vram: components may be parked on CPU between steps + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + # snap requested size to the model's divisibility + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + img = pipeline( + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, # usually None; pre-made noise if set + generator=generator, + )[0] + return img + + # ------------------------------------------------------------------ + # Training hooks + # ------------------------------------------------------------------ + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + text_embeddings: AdvancedPromptEmbeds, + **kwargs, + ): + """The actual forward pass of the denoiser. Called every train step + (with grads) via BaseModel.predict_noise, and also by some adapters. 
+ + in: + latent_model_input (B, C, h, w) noisy latents: the output of + add_noise(clean_latents, noise, timestep), after + condition_noisy_latents (channel-concat models + would see extra channels here). + For video models this is (B, C, frames, h, w). + timestep (B,) float on the 0..1000 scale, 1000 = pure noise + text_embeddings AdvancedPromptEmbeds for the batch; every key you + stored in get_prompt_embeds holds a list of B + tensors (cached per-sample embeds are expanded / + concatenated for you) + **kwargs may include ``batch`` (DataLoaderBatchDTO), + guidance_embedding_scale, adapter residuals, ... + only passed if your signature declares them + + out: + (B, C, h, w) the model prediction. For flow matching that is the + velocity in the same convention as get_loss_target (here: + noise - clean). Shape must match the TARGET latents -- if you + concatenated control channels/tokens in, slice them off before + returning (see ../flux_kontext/flux_kontext.py). + """ + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + # toolkit timestep (0..1000) -> our model's flow time in [0, 1]. + # WATCH OUT: every model has its own time convention. If the original + # repo uses t=1 for clean images, flip it here (see + # ../ideogram4/src/pipeline.py predict_velocity for an example). + t01 = timestep.to(self.device_torch, dtype=torch.float32) / 1000.0 + + # per-sample embed lists -> padded batch tensor + attention mask + llm_features, text_mask = pad_prompt_embeds( + text_embeddings.text_embeds, self.device_torch, self.torch_dtype + ) + + noise_pred = self.model( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=t01, + encoder_hidden_states=llm_features, + attention_mask=text_mask, + ) + return noise_pred + + def get_prompt_embeds(self, prompt) -> AdvancedPromptEmbeds: + """Encode prompt text into whatever conditioning the model eats. 
+ + Called for dataset captions (optionally cached to disk per caption), + for sample prompts, and for the empty string (unconditional). + + in: prompt a str or list[str] + out: AdvancedPromptEmbeds. Each key holds a LIST of tensors, one per + prompt, each at its natural (unpadded) length. Padding to the + batch max is deferred to get_noise_prediction / the pipeline, + which keeps caches small and lets any prompts share a batch. + + Each per-prompt tensor MUST be 2D ``(L, D)`` -- BaseModel infers the + text batch size from the list and only treats it as one-per-prompt + when the tensors are 2D; a 3D per-prompt tensor is misread as an + already-batched ``(B, L, D)`` and training fails with a latents-vs- + text batch-size mismatch. If your conditioning has an extra axis + (e.g. N stacked encoder layers -> ``(L, N, D)``), flatten it here + (``(L, N*D)``) and restore it (``reshape(B, Lt, N, D)``) at the + model call. + + You can store any number of keys (pooled embeds, image features, + ...). If a key must keep its dtype when everything else is cast + (masks, token ids), list it in ``embeds.frozen_dtype_keys``. + + NOTE: if you change how embeddings are computed after release, bump + ``text_embedding_space_version`` (a property on BaseModel) to + invalidate users' on-disk caches. 
+ """ + if isinstance(prompt, str): + prompt = [prompt] + + # low_vram support: TE might be parked on CPU + if self.text_encoder.device == torch.device("cpu"): + self.text_encoder.to(self.device_torch) + + embeds_list = [] + for p in prompt: + tokens = self.tokenizer( + p, + truncation=True, + max_length=self.max_text_length, + return_tensors="pt", + ).to(self.text_encoder.device) + # no padding: encode each prompt at its own length + with torch.no_grad(): + output = self.text_encoder(**tokens, output_hidden_states=True) + # (L, D) -- drop the batch dim, one tensor per prompt + embeds_list.append(output.last_hidden_state[0].to(self.torch_dtype)) + + return AdvancedPromptEmbeds(text_embeds=embeds_list) + + def get_loss_target(self, *args, **kwargs): + """The ground-truth tensor the prediction is MSE'd against. + + kwargs: noise (B, C, h, w), batch (DataLoaderBatchDTO with .latents = + the clean latents), timesteps. For flow matching the velocity target + is noise - clean. Must be detached. + """ + noise = kwargs.get("noise") + batch = kwargs.get("batch") + return (noise - batch.latents).detach() + + def condition_noisy_latents( + self, latents: torch.Tensor, batch + ) -> torch.Tensor: + """Optional hook: modify noisy latents before the model sees them. + + Called every train step right after noise is added. This is THE hook + for editing / inpainting / i2v models that feed reference latents in + alongside the noisy target (the reference is concatenated here, then + consumed -- and sliced off the prediction -- in get_noise_prediction). + + in: latents (B, C, h, w) noisy latents + batch DataLoaderBatchDTO -- batch.control_tensor holds the + control image(s) as (B, 3, H, W) in [0, 1] when the + dataset config has a control_path + out: latents, conditioned (return .detach()'d -- no grads here) + + This base text-to-image model needs nothing, so it passes through. 
+ Real examples: ../flux_kontext/flux_kontext.py (concat control latents + as extra tokens), ../qwen_image/qwen_image_edit.py. + """ + return latents + + # ------------------------------------------------------------------ + # VAE encode / decode + # ------------------------------------------------------------------ + # BaseModel.encode_images / decode_latents already handle a diffusers + # AutoencoderKL (scaling_factor / shift_factor) and would work unchanged + # for this model. They are overridden here anyway to document the + # contract, since custom VAEs (or latent normalization, patchified + # latents, video VAEs...) usually need it. + + def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None): + """Pixels -> latents. Used for latent caching and for control images. + + in: image_list list of (3, H, W) tensors -- or a (B, 3, H, W) batch -- + with values in [-1, 1], already crop/bucket-sized + out: (B, C, h, w) latents, normalized the way the transformer expects + (for AutoencoderKL: (z - shift_factor) * scaling_factor) + """ + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + if self.vae.device == torch.device("cpu"): + self.vae.to(self.vae_device_torch) + + if isinstance(image_list, list): + images = torch.stack(image_list, dim=0) + else: + images = image_list + images = images.to(device, dtype=dtype) + + latents = self.vae.encode(images).latent_dist.sample() + shift = self.vae.config["shift_factor"] or 0 + latents = (latents - shift) * self.vae.config["scaling_factor"] + return latents.to(device, dtype=dtype) + + def decode_latents(self, latents: torch.Tensor, device=None, dtype=None): + """Latents -> pixels. Used when rendering previews. 
+ + in: (B, C, h, w) latents in the normalized space encode_images produces + out: (B, 3, H, W) images in [-1, 1] + """ + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + if self.vae.device == torch.device("cpu"): + self.vae.to(self.vae_device_torch) + + latents = latents.to(device, dtype=dtype) + shift = self.vae.config["shift_factor"] or 0 + latents = latents / self.vae.config["scaling_factor"] + shift + return self.vae.decode(latents).sample + + # ------------------------------------------------------------------ + # Saving / bookkeeping + # ------------------------------------------------------------------ + def get_model_has_grad(self): + """True only if the base denoiser weights themselves require grad + (full fine-tune). LoRA training: False. Used to save/restore device + and grad state around sampling.""" + return False + + def get_te_has_grad(self): + """Same as above for the text encoder.""" + return False + + def save_model(self, output_path, meta, save_dtype): + """Save the FULL model (fine-tune checkpoints; LoRA saving is handled + elsewhere and only consults convert_lora_weights_before_save). + + ``output_path`` is a directory (no extension). Save in whatever layout + load_model can read back; include aitk_meta.yaml for provenance. 
+ """ + transformer: ExampleTransformer2DModel = unwrap_model(self.model) + os.makedirs(os.path.join(output_path, "transformer"), exist_ok=True) + state_dict = { + k: v.clone().to("cpu", dtype=save_dtype) + for k, v in transformer.state_dict().items() + } + save_file( + state_dict, os.path.join(output_path, "transformer", "model.safetensors") + ) + with open(os.path.join(output_path, "aitk_meta.yaml"), "w") as f: + yaml.dump(meta, f) + + def get_base_model_version(self): + """Free-form version string written into LoRA metadata so other tools + can identify the base model family.""" + return "example.1" + + def get_transformer_block_names(self) -> Optional[List[str]]: + """Attribute name(s) on self.model that hold the repeated transformer + blocks (a ModuleList). Used for LoRA block targeting; must match the + attribute in src/model.py.""" + return ["blocks"] + + def convert_lora_weights_before_save(self, state_dict): + """Map internal LoRA keys to the ecosystem-standard naming right before + the .safetensors is written. 
Most modern models ship LoRAs with a + ``diffusion_model.`` prefix (ComfyUI convention); internally ai-toolkit + uses ``transformer.``.""" + return { + k.replace("transformer.", "diffusion_model."): v + for k, v in state_dict.items() + } + + def convert_lora_weights_before_load(self, state_dict): + """Inverse of the above, applied when resuming from a saved LoRA.""" + return { + k.replace("diffusion_model.", "transformer."): v + for k, v in state_dict.items() + } diff --git a/ai-toolkit/extensions_built_in/diffusion_models/example_model/src/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/example_model/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3874739564b68596f9e76e32ea9d3b632b622725 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/example_model/src/__init__.py @@ -0,0 +1,4 @@ +# Everything diffusers does NOT provide for your model lives in src/: +# the network architecture and a minimal sampling pipeline. +from .model import ExampleTransformer2DModel +from .pipeline import ExamplePipeline, pad_prompt_embeds diff --git a/ai-toolkit/extensions_built_in/diffusion_models/example_model/src/model.py b/ai-toolkit/extensions_built_in/diffusion_models/example_model/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..30fc34827cd3838dfcedd13b825915d3a5722d70 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/example_model/src/model.py @@ -0,0 +1,274 @@ +"""A minimal diffusion transformer (DiT) used by the example model extension. + +This file stands in for the situation where diffusers does NOT have your model. +You vendor the architecture yourself inside your extension's ``src/`` folder and +load the weights manually in your model class (see ``../example_model.py``). 
+ +The architecture here is intentionally tiny and boring: + + latents (B, C, h, w) + -> patchify with a strided conv (B, N_img, hidden) + text embeds (B, L, text_dim) + -> linear projection (B, L, hidden) + concat [text | image] into one joint sequence (B, L + N_img, hidden) + -> N transformer blocks (self attention + mlp, adaLN-zero + modulated by the timestep embedding) + -> final modulated norm + linear + take only the image tokens and unpatchify back to (B, C, h, w) + +Real models add RoPE position embeddings, fancier attention, guidance +embeddings, etc. For real-world reference implementations in this repo see: + - ../../chroma/src/model.py (flux-style double/single stream blocks) + - ../../ernie_image/transformer.py (diffusers ModelMixin based) + - ../../ideogram4/src/transformer.py (packed single-sequence model) + +GRADIENT CHECKPOINTING +====================== +ai-toolkit enables gradient checkpointing on your model from +``jobs/process/BaseSDTrainProcess.py`` which does, in order of preference: + + if hasattr(unet, 'enable_gradient_checkpointing'): + unet.enable_gradient_checkpointing() + elif hasattr(unet, 'gradient_checkpointing'): + unet.gradient_checkpointing = True + +So a custom model only needs: + 1. a ``self.gradient_checkpointing`` flag (default False) + 2. (optionally) an ``enable_gradient_checkpointing()`` method + 3. to wrap each transformer block call in ``torch.utils.checkpoint.checkpoint`` + when the flag is set AND grads are enabled. + +IMPORTANT: gate on ``torch.is_grad_enabled()``, NOT on ``self.training``. +Sampling runs under ``torch.no_grad()`` where checkpointing is pure overhead, +and some training setups (e.g. certain adapters) run the module in eval mode +while still needing gradients. ``torch.is_grad_enabled()`` handles both. 
+""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + + +def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor: + """Standard sinusoidal embedding. + + in: t (B,) float tensor, the flow-matching time in [0, 1] (1 = pure noise) + out: emb (B, dim) + + We scale t by 1000 before embedding so the sinusoids get a useful range, + the same trick flux and friends use. + """ + t = t.float() * 1000.0 + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None] * freqs[None] + return torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + +class ExampleTransformerBlock(nn.Module): + """One DiT block: adaLN-zero modulated self-attention + MLP. + + in: x (B, S, hidden) the joint [text | image] token sequence + temb (B, hidden) the timestep embedding + attn_mask (B, 1, 1, S) bool, True = attend, False = padding + out: x (B, S, hidden) + """ + + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0): + super().__init__() + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.qkv = nn.Linear(hidden_size, hidden_size * 3) + self.proj = nn.Linear(hidden_size, hidden_size) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden = int(hidden_size * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden, hidden_size), + ) + + # adaLN-zero: timestep embedding -> shift/scale/gate for attn and mlp. + # Zero-init so the block starts as identity (standard DiT trick). 
+ self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size) + ) + nn.init.zeros_(self.adaLN_modulation[-1].weight) + nn.init.zeros_(self.adaLN_modulation[-1].bias) + + def forward(self, x: torch.Tensor, temb: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + b, s, d = x.shape + shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = ( + self.adaLN_modulation(temb).unsqueeze(1).chunk(6, dim=-1) + ) # each (B, 1, hidden), broadcasts over the sequence + + # --- attention --- + # ALWAYS default to torch's built-in scaled_dot_product_attention so the + # model runs with no extra dependency. If the reference repo you are + # porting hard-codes flash-attn (or xformers, sage, ...), do NOT carry + # that requirement over -- make it OPTIONAL. The clean pattern is a + # per-module ``attention_backend`` flag toggled in bulk from the parent + # model (e.g. ``set_attention_backend("flash")``), branching to the + # flash kernel only when explicitly selected AND the package is present. + # See ../../ideogram4/src/transformer.py and ../../boogu_image/src for + # working "native" (SDPA) + optional "flash" implementations. + h = self.norm1(x) * (1 + scale_a) + shift_a + q, k, v = self.qkv(h).chunk(3, dim=-1) + q = q.view(b, s, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(b, s, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(b, s, self.num_heads, self.head_dim).transpose(1, 2) + h = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + h = h.transpose(1, 2).reshape(b, s, d) + x = x + gate_a * self.proj(h) + + # --- mlp --- + h = self.norm2(x) * (1 + scale_m) + shift_m + x = x + gate_m * self.mlp(h) + return x + + +class ExampleTransformer2DModel(nn.Module): + """The denoiser. Plain ``nn.Module`` on purpose. + + You could also subclass ``diffusers.ModelMixin``/``ConfigMixin`` (see + ../../ernie_image/transformer.py) to get ``save_pretrained``, + ``_gradient_checkpointing_func`` etc. 
for free, but a plain module shows + exactly what ai-toolkit actually requires, which is very little: + + - a forward pass + - ``device`` / ``dtype`` properties (BaseModel reads ``self.model.device`` + and ``self.model.dtype`` in a few places, e.g. save_device_state) + - the gradient checkpointing flag described in the module docstring + + NOTE: the class NAME matters. ``ExampleModel.target_lora_modules`` lists + "ExampleTransformer2DModel" -- that string is matched against module class + names when deciding where to attach LoRA layers. + """ + + def __init__( + self, + in_channels: int = 16, # VAE latent channels + out_channels: int = 16, # predicted velocity has the same channels + patch_size: int = 2, # latent pixels per token side + hidden_size: int = 1024, + num_heads: int = 16, + num_layers: int = 12, + text_dim: int = 2048, # width of the text encoder hidden states + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_size = patch_size + self.hidden_size = hidden_size + + # latent (B, C, h, w) -> image tokens (B, N_img, hidden) + self.x_embedder = nn.Conv2d( + in_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + # text encoder hidden states -> model width + self.text_proj = nn.Linear(text_dim, hidden_size) + # sinusoidal timestep embedding -> mlp + self.t_embedder = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + + # ``blocks`` is the repeated-layer ModuleList. The attribute name is + # what ExampleModel.get_transformer_block_names() returns, which the + # LoRA code uses for block targeting / "transformer only" training. + self.blocks = nn.ModuleList( + [ + ExampleTransformerBlock(hidden_size, num_heads) + for _ in range(num_layers) + ] + ) + + # final adaLN + projection back to patch pixels, zero-init so the + # untrained model predicts zeros. 
+ self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.adaLN_out = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size)) + self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * out_channels) + nn.init.zeros_(self.adaLN_out[-1].weight) + nn.init.zeros_(self.adaLN_out[-1].bias) + nn.init.zeros_(self.proj_out.weight) + nn.init.zeros_(self.proj_out.bias) + + # gradient checkpointing flag, flipped on by the trainer (see module + # docstring). Off by default so inference pays no cost. + self.gradient_checkpointing = False + + # the trainer prefers this method if it exists + def enable_gradient_checkpointing(self, enable: bool = True): + self.gradient_checkpointing = enable + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def forward( + self, + hidden_states: torch.Tensor, # (B, C, h, w) noisy latents + timestep: torch.Tensor, # (B,) flow time in [0, 1], 1 = pure noise + encoder_hidden_states: torch.Tensor, # (B, L, text_dim) padded text features + attention_mask: torch.Tensor, # (B, L) 1 = real text token, 0 = padding + ) -> torch.Tensor: + """Predict the flow-matching velocity. + + out: (B, C, h, w) velocity in the ai-toolkit convention + (noise - clean), matching ExampleModel.get_loss_target(). 
+ """ + b, c, h, w = hidden_states.shape + p = self.patch_size + gh, gw = h // p, w // p + n_img = gh * gw + + # tokens + img = self.x_embedder(hidden_states) # (B, hidden, gh, gw) + img = img.flatten(2).transpose(1, 2) # (B, N_img, hidden) + txt = self.text_proj(encoder_hidden_states) # (B, L, hidden) + x = torch.cat([txt, img], dim=1) # (B, L + N_img, hidden) + + # timestep conditioning + temb = self.t_embedder(timestep_embedding(timestep, self.hidden_size)) + temb = temb.to(x.dtype) + + # joint attention mask: text padding is masked out, image tokens and + # real text tokens attend everywhere. (B, 1, 1, S) bool for sdpa. + img_mask = torch.ones(b, n_img, dtype=torch.bool, device=x.device) + attn_mask = torch.cat([attention_mask.bool(), img_mask], dim=1) + attn_mask = attn_mask[:, None, None, :] + + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + # Recompute this block's activations during backward instead + # of storing them -- trades compute for a big VRAM saving. + # use_reentrant=False is the modern, correct variant. 
+ x = checkpoint(block, x, temb, attn_mask, use_reentrant=False) + else: + x = block(x, temb, attn_mask) + + # final modulation + project, keep only the image tokens + shift, scale = self.adaLN_out(temb).unsqueeze(1).chunk(2, dim=-1) + x = self.norm_out(x) * (1 + scale) + shift + x = self.proj_out(x)[:, -n_img:] # (B, N_img, p*p*C) + + # unpatchify back to the latent layout + x = x.view(b, gh, gw, p, p, self.out_channels) + x = x.permute(0, 5, 1, 3, 2, 4).reshape(b, self.out_channels, h, w) + return x diff --git a/ai-toolkit/extensions_built_in/diffusion_models/example_model/src/pipeline.py b/ai-toolkit/extensions_built_in/diffusion_models/example_model/src/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..30416e36d08a177df0c6d8e11b768e7c2b8e1d14 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/example_model/src/pipeline.py @@ -0,0 +1,158 @@ +"""A minimal sampling pipeline for the example model. + +ai-toolkit only uses your pipeline to render preview/sample images during +training (see BaseModel.generate_images -> ExampleModel.generate_single_image). +It does NOT need to be a diffusers DiffusionPipeline, and because ai-toolkit +always encodes the prompts itself (so it can cache embeds, apply trigger words, +run adapters, etc.) the pipeline never sees raw prompt strings -- only +already-encoded ``AdvancedPromptEmbeds``. + +So all a pipeline has to do is: + + 1. make starting noise + 2. loop the scheduler over timesteps, calling the transformer + 3. apply classifier-free guidance (cond vs uncond prediction) + 4. decode the final latents with the VAE and return PIL images + +The pattern of passing the whole BaseModel instance into the pipeline (rather +than individual components) is borrowed from ../../ideogram4/src/pipeline.py. +It keeps the pipeline tiny because it can reuse the model's scheduler factory, +``decode_latents`` and device/dtype bookkeeping. 
+""" + +from typing import List, Optional + +import torch +from PIL import Image +from diffusers.utils.torch_utils import randn_tensor + + +def pad_prompt_embeds( + embeds_list: List[torch.Tensor], + device: torch.device, + dtype: torch.dtype, +): + """Right-pad a list of per-sample text features into one batch tensor. + + in: embeds_list list (len B) of (L_i, D) tensors -- this is exactly what + ``AdvancedPromptEmbeds.text_embeds`` holds: one tensor per + batch item, each at its own natural length. + out: features (B, L_max, D) zero-padded on the right + mask (B, L_max) long, 1 = real token, 0 = padding + + Storing embeds unpadded per item and only padding at the model call is the + preferred pattern: cached embeds stay small, and items of very different + prompt lengths can share a batch. + """ + lengths = [e.shape[0] for e in embeds_list] + max_len = max(lengths) + dim = embeds_list[0].shape[-1] + batch_size = len(embeds_list) + + features = torch.zeros(batch_size, max_len, dim, device=device, dtype=dtype) + mask = torch.zeros(batch_size, max_len, dtype=torch.long, device=device) + for i, e in enumerate(embeds_list): + n = e.shape[0] + features[i, :n] = e.to(device, dtype) + mask[i, :n] = 1 + return features, mask + + +class ExamplePipeline: + """Lightweight flow-matching sampler used for training previews.""" + + def __init__(self, model): + # ``model`` is the ExampleModel (a BaseModel subclass), giving us + # access to model.transformer, model.vae, model.decode_latents, etc. + self.model = model + + @property + def device(self): + return self.model.device_torch + + def to(self, *args, **kwargs): + # BaseModel.generate_images may call pipeline.to(device); we manage + # devices through the model itself, so this is a no-op. + return self + + def set_progress_bar_config(self, **kwargs): + # called by the sampler harness (inside a try/except, so optional); + # diffusers pipelines use it to silence tqdm. Nothing to do here. 
+ pass + + @torch.no_grad() + def __call__( + self, + # AdvancedPromptEmbeds with key ``text_embeds`` (list of (L, D) tensors) + conditional_embeds, + unconditional_embeds, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 25, + guidance_scale: float = 4.0, + latents: Optional[torch.Tensor] = None, # pre-made noise, usually None + generator: Optional[torch.Generator] = None, # seeded RNG for reproducible samples + **kwargs, + ) -> List[Image.Image]: + model = self.model + device = model.device_torch + dtype = model.torch_dtype + transformer = model.transformer + + # Always sample with a FRESH scheduler. The training scheduler is + # stateful; mutating it mid-training would corrupt the train step. + scheduler = model.get_train_scheduler() + scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = scheduler.timesteps # 1000 -> 0 scale + + # pixel size -> latent size (VAE downsample only; the transformer + # patchifies internally so latents stay unpacked here) + gh = height // model.vae_scale_factor + gw = width // model.vae_scale_factor + + do_cfg = unconditional_embeds is not None and guidance_scale != 1.0 + + # 1. starting noise (keep it float32; cast per model call) + if latents is None: + shape = (1, transformer.in_channels, gh, gw) + latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) + latents = latents.to(device, dtype=torch.float32) + + # 2. pad the per-item embed lists into batch tensors once, up front + cond_feats, cond_mask = pad_prompt_embeds(conditional_embeds.text_embeds, device, dtype) + if do_cfg: + uncond_feats, uncond_mask = pad_prompt_embeds(unconditional_embeds.text_embeds, device, dtype) + + # 3. 
denoising loop + for t in timesteps: + # scheduler timesteps are on a 0-1000 scale; the transformer wants + # flow time in [0, 1] with 1 = pure noise + t01 = (t / 1000.0).to(device).expand(latents.shape[0]) + + v_cond = transformer( + hidden_states=latents.to(dtype), + timestep=t01, + encoder_hidden_states=cond_feats, + attention_mask=cond_mask, + ) + if do_cfg: + v_uncond = transformer( + hidden_states=latents.to(dtype), + timestep=t01, + encoder_hidden_states=uncond_feats, + attention_mask=uncond_mask, + ) + # classifier-free guidance: push the prediction away from the + # unconditional (negative prompt) direction + v = v_uncond + guidance_scale * (v_cond - v_uncond) + else: + v = v_cond + + latents = scheduler.step(v.to(torch.float32), t, latents, return_dict=False)[0] + + # 4. decode latents -> images in [-1, 1] -> uint8 PIL + images = model.decode_latents(latents, device=device, dtype=dtype) + images = images.float().clamp(-1.0, 1.0) + images = ((images + 1.0) * 127.5).round().to(torch.uint8) + images = images.permute(0, 2, 3, 1).cpu().numpy() + return [Image.fromarray(arr) for arr in images] diff --git a/ai-toolkit/extensions_built_in/diffusion_models/f_light/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/f_light/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a438f93b59b1dd0caa60d10e2da615652c036c3 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/f_light/__init__.py @@ -0,0 +1 @@ +from .f_light import FLiteModel \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/f_light/f_light.py b/ai-toolkit/extensions_built_in/diffusion_models/f_light/f_light.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc4f5a09ad7bddf7684d4a237f5da954a5f5ae7 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/f_light/f_light.py @@ -0,0 +1,295 @@ +import os +from typing import TYPE_CHECKING + +import torch +import yaml +from 
toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from diffusers import AutoencoderKL +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype +from transformers import T5TokenizerFast, T5EncoderModel +from .src import FLitePipeline, DiT + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True +} + + +class FLiteModel(BaseModel): + arch = "f-lite" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, + model_config, + dtype, + custom_pipeline, + noise_scheduler, + **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['DiT'] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + # return the bucket divisibility for the model + return 16 + + def load_model(self): + dtype = self.torch_dtype + + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + + extras_path = self.model_config.extras_name_or_path + + self.print_and_status_update("Loading transformer") + + transformer = DiT.from_pretrained( + model_path, + subfolder="dit_model", + torch_dtype=dtype, + ) + 
+ transformer.to(self.quantize_device, dtype=dtype) + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + self.print_and_status_update("Loading T5") + tokenizer = T5TokenizerFast.from_pretrained( + extras_path, subfolder="tokenizer", torch_dtype=dtype + ) + text_encoder = T5EncoderModel.from_pretrained( + extras_path, subfolder="text_encoder", torch_dtype=dtype + ) + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder, weights=get_qtype( + self.model_config.qtype)) + freeze(text_encoder) + flush() + + self.noise_scheduler = FLiteModel.get_train_scheduler() + + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + extras_path, + subfolder="vae", + torch_dtype=dtype + ) + vae = vae.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Making pipe") + + pipe: FLitePipeline = FLitePipeline( + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + dit_model=None, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.dit_model = transformer + pipe.transformer = transformer + pipe.scheduler = self.noise_scheduler, + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + 
text_encoder[0].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = FLiteModel.get_train_scheduler() + # it has built in scheduler. Basically euler flowmatching + pipeline = FLitePipeline( + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + vae=unwrap_model(self.vae), + dit_model=unwrap_model(self.transformer) + ) + pipeline.transformer = pipeline.dit_model + pipeline.scheduler = scheduler + + return pipeline + + def generate_single_image( + self, + pipeline: FLitePipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + + extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + cast_dtype = self.unet.dtype + + noise_pred = self.unet( + latent_model_input.to( + self.device_torch, cast_dtype + ), + text_embeddings.text_embeds.to( + self.device_torch, cast_dtype + ), + timestep / 1000, + ) + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if 
isinstance(prompt, str): + prompts = [prompt] + else: + prompts = prompt + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + prompt_embeds, negative_embeds = self.pipeline.encode_prompt( + prompt=prompts, + negative_prompt=None, + device=self.text_encoder[0].device, + dtype=self.torch_dtype, + ) + + pe = PromptEmbeds(prompt_embeds) + + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return False + + def get_te_has_grad(self): + # return from a weight if it has grad + return False + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: DiT = unwrap_model(self.model) + # diffusers + # only save the unet + transformer: DiT = unwrap_model(self.transformer) + transformer.save_pretrained( + save_directory=os.path.join(output_path, 'dit_model'), + safe_serialization=True, + ) + # save out meta config + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + # return (noise - batch.latents).detach() + return (batch.latents - noise).detach() + + def convert_lora_weights_before_save(self, state_dict): + # currently starte with transformer. but needs to start with diffusion_model. for comfyui + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + # saved as diffusion_model. but needs to be transformer. 
for ai-toolkit + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + + def get_base_model_version(self): + return "f-lite" + + def get_stepped_pred(self, pred, noise): + # just used for DFE support + latents = pred + noise + return latents diff --git a/ai-toolkit/extensions_built_in/diffusion_models/f_light/src/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/f_light/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e51652f708e0657d6048c27161ec28864f57e54 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/f_light/src/__init__.py @@ -0,0 +1,5 @@ +from .pipeline import FLitePipeline, FLitePipelineOutput, APGConfig +from .model import DiT + + +__all__ = ["FLitePipeline", "FLitePipelineOutput", "APGConfig", "DiT"] \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/f_light/src/model.py b/ai-toolkit/extensions_built_in/diffusion_models/f_light/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..903d49280db8a614a466470f6acf0e3f49539b01 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/f_light/src/model.py @@ -0,0 +1,456 @@ +# originally from https://github.com/fal-ai/f-lite/blob/main/f_lite/model.py but modified slightly + +import math + +import torch +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from einops import rearrange +from peft import get_peft_model_state_dict, set_peft_model_state_dict +from torch import nn + + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, 
dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + return embedding + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-6, trainable=False): + super().__init__() + self.eps = eps + if trainable: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, x): + x_dtype = x.dtype + x = x.float() + norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + if self.weight is not None: + return (x * norm * self.weight).to(dtype=x_dtype) + else: + return (x * norm).to(dtype=x_dtype) + + +class QKNorm(nn.Module): + """Normalizing the query and the key independently, as Flux proposes""" + + def __init__(self, dim, trainable=False): + super().__init__() + self.query_norm = RMSNorm(dim, trainable=trainable) + self.key_norm = RMSNorm(dim, trainable=trainable) + + def forward(self, q, k): + q = self.query_norm(q) + k = self.key_norm(k) + return q, k + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + is_self_attn=True, + cross_attn_input_size=None, + residual_v=False, + dynamic_softmax_temperature=False, + ): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.is_self_attn = is_self_attn + self.residual_v = residual_v + self.dynamic_softmax_temperature = dynamic_softmax_temperature + + if is_self_attn: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + else: + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.context_kv = nn.Linear(cross_attn_input_size, dim * 2, bias=qkv_bias) + + self.proj = nn.Linear(dim, dim, bias=False) + + if residual_v: + self.lambda_param = nn.Parameter(torch.tensor(0.5).reshape(1)) + + self.qk_norm = QKNorm(self.head_dim) + + def forward(self, x, context=None, v_0=None, rope=None): + if self.is_self_attn: + qkv = self.qkv(x) + 
qkv = rearrange(qkv, "b l (k h d) -> k b h l d", k=3, h=self.num_heads) + q, k, v = qkv.unbind(0) + + if self.residual_v and v_0 is not None: + v = self.lambda_param * v + (1 - self.lambda_param) * v_0 + + if rope is not None: + # print(q.shape, rope[0].shape, rope[1].shape) + q = apply_rotary_emb(q, rope[0], rope[1]) + k = apply_rotary_emb(k, rope[0], rope[1]) + + # https://arxiv.org/abs/2306.08645 + # https://arxiv.org/abs/2410.01104 + # ratioonale is that if tokens get larger, categorical distribution get more uniform + # so you want to enlargen entropy. + + token_length = q.shape[2] + if self.dynamic_softmax_temperature: + ratio = math.sqrt(math.log(token_length) / math.log(1040.0)) # 1024 + 16 + k = k * ratio + q, k = self.qk_norm(q, k) + + else: + q = rearrange(self.q(x), "b l (h d) -> b h l d", h=self.num_heads) + kv = rearrange( + self.context_kv(context), + "b l (k h d) -> k b h l d", + k=2, + h=self.num_heads, + ) + k, v = kv.unbind(0) + q, k = self.qk_norm(q, k) + + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b h l d -> b l (h d)") + x = self.proj(x) + return x, v if self.is_self_attn else None + + +class DiTBlock(nn.Module): + def __init__( + self, + hidden_size, + cross_attn_input_size, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + residual_v=False, + dynamic_softmax_temperature=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.norm1 = RMSNorm(hidden_size, trainable=qkv_bias) + self.self_attn = Attention( + hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + is_self_attn=True, + residual_v=residual_v, + dynamic_softmax_temperature=dynamic_softmax_temperature, + ) + + if cross_attn_input_size is not None: + self.norm2 = RMSNorm(hidden_size, trainable=qkv_bias) + self.cross_attn = Attention( + hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + is_self_attn=False, + cross_attn_input_size=cross_attn_input_size, + dynamic_softmax_temperature=dynamic_softmax_temperature, + ) + else: + self.norm2 = 
None + self.cross_attn = None + + self.norm3 = RMSNorm(hidden_size, trainable=qkv_bias) + mlp_hidden = int(hidden_size * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden), + nn.GELU(), + nn.Linear(mlp_hidden, hidden_size), + ) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 9 * hidden_size, bias=True)) + + self.adaLN_modulation[-1].weight.data.zero_() + self.adaLN_modulation[-1].bias.data.zero_() + + # @torch.compile(mode='reduce-overhead') + def forward(self, x, context, c, v_0=None, rope=None): + ( + shift_sa, + scale_sa, + gate_sa, + shift_ca, + scale_ca, + gate_ca, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation(c).chunk(9, dim=1) + + scale_sa = scale_sa[:, None, :] + scale_ca = scale_ca[:, None, :] + scale_mlp = scale_mlp[:, None, :] + + shift_sa = shift_sa[:, None, :] + shift_ca = shift_ca[:, None, :] + shift_mlp = shift_mlp[:, None, :] + + gate_sa = gate_sa[:, None, :] + gate_ca = gate_ca[:, None, :] + gate_mlp = gate_mlp[:, None, :] + + norm_x = self.norm1(x.clone()) + norm_x = norm_x * (1 + scale_sa) + shift_sa + attn_out, v = self.self_attn(norm_x, v_0=v_0, rope=rope) + x = x + attn_out * gate_sa + + if self.norm2 is not None: + norm_x = self.norm2(x) + norm_x = norm_x * (1 + scale_ca) + shift_ca + x = x + self.cross_attn(norm_x, context)[0] * gate_ca + + norm_x = self.norm3(x) + norm_x = norm_x * (1 + scale_mlp) + shift_mlp + x = x + self.mlp(norm_x) * gate_mlp + + return x, v + + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=16, in_channels=3, embed_dim=768): + super().__init__() + self.patch_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.patch_size = patch_size + + def forward(self, x): + B, C, H, W = x.shape + x = self.patch_proj(x) + x = rearrange(x, "b c h w -> b (h w) c") + return x + + +class TwoDimRotary(torch.nn.Module): + def __init__(self, dim, base=10000, h=256, w=256): + super().__init__() + self.inv_freq = 
torch.FloatTensor([1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]) + self.h = h + self.w = w + + t_h = torch.arange(h, dtype=torch.float32) + t_w = torch.arange(w, dtype=torch.float32) + + freqs_h = torch.outer(t_h, self.inv_freq).unsqueeze(1) # h, 1, d / 2 + freqs_w = torch.outer(t_w, self.inv_freq).unsqueeze(0) # 1, w, d / 2 + freqs_h = freqs_h.repeat(1, w, 1) # h, w, d / 2 + freqs_w = freqs_w.repeat(h, 1, 1) # h, w, d / 2 + freqs_hw = torch.cat([freqs_h, freqs_w], 2) # h, w, d + + self.register_buffer("freqs_hw_cos", freqs_hw.cos()) + self.register_buffer("freqs_hw_sin", freqs_hw.sin()) + + def forward(self, x, height_width=None, extend_with_register_tokens=0): + if height_width is not None: + this_h, this_w = height_width + else: + this_hw = x.shape[1] + this_h, this_w = int(this_hw**0.5), int(this_hw**0.5) + + cos = self.freqs_hw_cos[0 : this_h, 0 : this_w] + sin = self.freqs_hw_sin[0 : this_h, 0 : this_w] + + cos = cos.clone().reshape(this_h * this_w, -1) + sin = sin.clone().reshape(this_h * this_w, -1) + + # append N of zero-attn tokens + if extend_with_register_tokens > 0: + cos = torch.cat( + [ + torch.ones(extend_with_register_tokens, cos.shape[1]).to(cos.device), + cos, + ], + 0, + ) + sin = torch.cat( + [ + torch.zeros(extend_with_register_tokens, sin.shape[1]).to(sin.device), + sin, + ], + 0, + ) + + return cos[None, None, :, :], sin[None, None, :, :] # [1, 1, T + N, Attn-dim] + + +def apply_rotary_emb(x, cos, sin): + orig_dtype = x.dtype + x = x.to(dtype=torch.float32) + assert x.ndim == 4 # multihead attention + d = x.shape[3] // 2 + x1 = x[..., :d] + x2 = x[..., d:] + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat([y1, y2], 3).to(dtype=orig_dtype) + + +class DiT(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): # type: ignore[misc] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels=4, + patch_size=2, + hidden_size=1152, + depth=28, + num_heads=16, 
+ mlp_ratio=4.0, + cross_attn_input_size=128, + residual_v=False, + train_bias_and_rms=True, + use_rope=True, + gradient_checkpoint=False, + dynamic_softmax_temperature=False, + rope_base=10000, + ): + super().__init__() + + self.patch_embed = PatchEmbed(patch_size, in_channels, hidden_size) + + if use_rope: + self.rope = TwoDimRotary(hidden_size // (2 * num_heads), base=rope_base, h=512, w=512) + else: + self.positional_embedding = nn.Parameter(torch.zeros(1, 2048, hidden_size)) + + self.register_tokens = nn.Parameter(torch.randn(1, 16, hidden_size)) + + self.time_embed = nn.Sequential( + nn.Linear(hidden_size, 4 * hidden_size), + nn.SiLU(), + nn.Linear(4 * hidden_size, hidden_size), + ) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + cross_attn_input_size=cross_attn_input_size, + residual_v=residual_v, + qkv_bias=train_bias_and_rms, + dynamic_softmax_temperature=dynamic_softmax_temperature, + ) + for _ in range(depth) + ] + ) + + self.final_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + self.final_norm = RMSNorm(hidden_size, trainable=train_bias_and_rms) + self.final_proj = nn.Linear(hidden_size, patch_size * patch_size * in_channels) + nn.init.zeros_(self.final_modulation[-1].weight) + nn.init.zeros_(self.final_modulation[-1].bias) + nn.init.zeros_(self.final_proj.weight) + nn.init.zeros_(self.final_proj.bias) + self.paramstatus = {} + for n, p in self.named_parameters(): + self.paramstatus[n] = { + "shape": p.shape, + "requires_grad": p.requires_grad, + } + self.gradient_checkpointing = False + + def save_lora_weights(self, save_directory): + """Save LoRA weights to a file""" + lora_state_dict = get_peft_model_state_dict(self) + torch.save(lora_state_dict, f"{save_directory}/lora_weights.pt") + + def load_lora_weights(self, load_directory): + """Load LoRA weights from a file""" + lora_state_dict = 
torch.load(f"{load_directory}/lora_weights.pt") + set_peft_model_state_dict(self, lora_state_dict) + + @apply_forward_hook + def forward(self, x, context, timesteps): + b, c, h, w = x.shape + x = self.patch_embed(x) # b, T, d + + x = torch.cat([self.register_tokens.repeat(b, 1, 1), x], 1) # b, T + N, d + + if self.config.use_rope: + cos, sin = self.rope( + x, + extend_with_register_tokens=16, + height_width=(h // self.config.patch_size, w // self.config.patch_size), + ) + else: + x = x + self.positional_embedding.repeat(b, 1, 1)[:, : x.shape[1], :] + cos, sin = None, None + + t_emb = timestep_embedding(timesteps * 1000, self.config.hidden_size).to(x.device, dtype=x.dtype) + t_emb = self.time_embed(t_emb) + + v_0 = None + + for _idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + x, v = self._gradient_checkpointing_func( + block, + x, + context, + t_emb, + v_0, + (cos, sin) + ) + else: + x, v = block(x, context, t_emb, v_0, (cos, sin)) + if v_0 is None: + v_0 = v + + x = x[:, 16:, :] + final_shift, final_scale = self.final_modulation(t_emb).chunk(2, dim=1) + x = self.final_norm(x) + x = x * (1 + final_scale[:, None, :]) + final_shift[:, None, :] + x = self.final_proj(x) + + x = rearrange( + x, + "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", + h=h // self.config.patch_size, + w=w // self.config.patch_size, + p1=self.config.patch_size, + p2=self.config.patch_size, + ) + return x + + +if __name__ == "__main__": + model = DiT( + in_channels=4, + patch_size=2, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + cross_attn_input_size=128, + residual_v=False, + train_bias_and_rms=True, + use_rope=True, + ).cuda() + print( + model( + torch.randn(1, 4, 64, 64).cuda(), + torch.randn(1, 37, 128).cuda(), + torch.tensor([1.0]).cuda(), + ) + ) \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/f_light/src/pipeline.py 
b/ai-toolkit/extensions_built_in/diffusion_models/f_light/src/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..69fc67b93f5d58d0b234c23f2c557c0e47a26947 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/f_light/src/pipeline.py @@ -0,0 +1,308 @@ +# originally from https://github.com/fal-ai/f-lite/blob/main/f_lite/pipeline.py but modified slightly +import logging +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers import AutoencoderKL, DiffusionPipeline +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor +from PIL import Image +from torch import FloatTensor +from tqdm.auto import tqdm +from transformers import T5EncoderModel, T5TokenizerFast + + + +logger = logging.getLogger(__name__) + + +@dataclass +class APGConfig: + """APG (Augmented Parallel Guidance) configuration""" + + enabled: bool = True + orthogonal_threshold: float = 0.03 + + +@dataclass +class FLitePipelineOutput(BaseOutput): + """ + Output class for FLitePipeline pipeline. + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[Image.Image], np.ndarray] + + +class FLitePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using F-Lite model. + This model inherits from [`DiffusionPipeline`]. 
+ """ + + model_cpu_offload_seq = "text_encoder->dit_model->vae" + + dit_model: torch.nn.Module + vae: AutoencoderKL + text_encoder: T5EncoderModel + tokenizer: T5TokenizerFast + _progress_bar_config: Dict[str, Any] + + def __init__( + self, dit_model: torch.nn.Module, vae: AutoencoderKL, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast + ): + super().__init__() + # Register all modules for the pipeline + # Access DiffusionPipeline's register_modules directly to avoid mypy error + DiffusionPipeline.register_modules( + self, dit_model=dit_model, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer + ) + + # Move models to channels last for better performance + # AutoencoderKL inherits from torch.nn.Module which has these methods + if hasattr(self.vae, "to"): + self.vae.to(memory_format=torch.channels_last) + if hasattr(self.vae, "requires_grad_"): + self.vae.requires_grad_(False) + if hasattr(self.text_encoder, "requires_grad_"): + self.text_encoder.requires_grad_(False) + + # Constants + self.vae_scale_factor = 8 + self.return_index = -8 # T5 hidden state index to use + + def enable_vae_slicing(self): + """Enable VAE slicing for memory efficiency.""" + if hasattr(self.vae, "enable_slicing"): + self.vae.enable_slicing() + + def enable_vae_tiling(self): + """Enable VAE tiling for memory efficiency.""" + if hasattr(self.vae, "enable_tiling"): + self.vae.enable_tiling() + + def set_progress_bar_config(self, **kwargs): + """Set progress bar configuration.""" + self._progress_bar_config = kwargs + + def progress_bar(self, iterable=None, **kwargs): + """Create progress bar for iterations.""" + self._progress_bar_config = getattr(self, "_progress_bar_config", None) or {} + config = {**self._progress_bar_config, **kwargs} + return tqdm(iterable, **config) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + 
max_sequence_length: int = 512, + return_index: int = -8, + ) -> Tuple[FloatTensor, FloatTensor]: + """Encodes the prompt and negative prompt.""" + if isinstance(prompt, str): + prompt = [prompt] + device = device or self.text_encoder.device + # Text encoder forward pass + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + prompt_embeds = self.text_encoder(text_input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds_tensor = prompt_embeds.hidden_states[return_index] + if return_index != -1: + prompt_embeds_tensor = self.text_encoder.encoder.final_layer_norm(prompt_embeds_tensor) + prompt_embeds_tensor = self.text_encoder.encoder.dropout(prompt_embeds_tensor) + + dtype = dtype or next(self.text_encoder.parameters()).dtype + prompt_embeds_tensor = prompt_embeds_tensor.to(dtype=dtype, device=device) + + # Handle negative prompts + if negative_prompt is None: + negative_embeds = torch.zeros_like(prompt_embeds_tensor) + else: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + negative_result = self.encode_prompt( + prompt=negative_prompt, device=device, dtype=dtype, return_index=return_index + ) + negative_embeds = negative_result[0] + + # Explicitly cast both tensors to FloatTensor for mypy + from typing import cast + + prompt_tensor = cast(FloatTensor, prompt_embeds_tensor.to(dtype=dtype)) + negative_tensor = cast(FloatTensor, negative_embeds.to(dtype=dtype)) + return (prompt_tensor, negative_tensor) + + def to(self, torch_device=None, torch_dtype=None, silence_dtype_warnings=False): + """Move pipeline components to specified device and dtype.""" + if hasattr(self, "vae"): + self.vae.to(device=torch_device, dtype=torch_dtype) + if hasattr(self, "text_encoder"): + self.text_encoder.to(device=torch_device, dtype=torch_dtype) + if hasattr(self, "dit_model"): + 
self.dit_model.to(device=torch_device, dtype=torch_dtype) + return self + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]]=None, + prompt_embeds: Optional[FloatTensor] = None, + height: Optional[int] = 1024, + width: Optional[int] = 1024, + num_inference_steps: int = 30, + guidance_scale: float = 6.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_embeds: Optional[FloatTensor] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + dtype: Optional[torch.dtype] = None, + alpha: Optional[float] = None, + apg_config: Optional[APGConfig] = None, + **kwargs, + ): + """Generate images from text prompt.""" + # Ensure height and width are not None for calculation + if height is None: + height = 1024 + if width is None: + width = 1024 + + dtype = dtype or next(self.dit_model.parameters()).dtype + apg_config = apg_config or APGConfig(enabled=False) + + device = self._execution_device + + # 2. Encode prompts + prompt_batch_size = len(prompt) if isinstance(prompt, list) else 1 + batch_size = prompt_batch_size * num_images_per_prompt + + if prompt_embeds is None or negative_prompt_embeds is None: + prompt_embeds, negative_embeds = self.encode_prompt( + prompt=prompt, negative_prompt=negative_prompt, device=self.text_encoder.device, dtype=dtype, + return_index=self.return_index, + ) + else: + negative_embeds = negative_prompt_embeds + + # Repeat embeddings for num_images_per_prompt + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_embeds = negative_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + # 3. 
Initialize latents + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError(f"Got {len(generator)} generators for {batch_size} samples") + + latents = randn_tensor((batch_size, 16, latent_height, latent_width), generator=generator, device=device, dtype=dtype) + acc_latents = latents.clone() + + # 4. Calculate alpha if not provided + if alpha is None: + image_token_size = latent_height * latent_width + alpha = 2 * math.sqrt(image_token_size / (64 * 64)) + + # 6. Sampling loop + self.dit_model.eval() + + # Check if guidance is needed + do_classifier_free_guidance = guidance_scale >= 1.0 + + for i in self.progress_bar(range(num_inference_steps, 0, -1)): + # Calculate timesteps + t = i / num_inference_steps + t_next = (i - 1) / num_inference_steps + # Scale timesteps according to alpha + t = t * alpha / (1 + (alpha - 1) * t) + t_next = t_next * alpha / (1 + (alpha - 1) * t_next) + dt = t - t_next + + # Create tensor with proper device + t_tensor = torch.tensor([t] * batch_size, device=device, dtype=dtype) + + if do_classifier_free_guidance: + # Duplicate latents for both conditional and unconditional inputs + latents_input = torch.cat([latents] * 2) + # Concatenate negative and positive prompt embeddings + context_input = torch.cat([negative_embeds, prompt_embeds]) + # Duplicate timesteps for the batch + t_input = torch.cat([t_tensor] * 2) + + # Get model predictions in a single pass + model_outputs = self.dit_model(latents_input, context_input, t_input) + + # Split outputs back into unconditional and conditional predictions + uncond_output, cond_output = model_outputs.chunk(2) + + if apg_config.enabled: + # Augmented Parallel Guidance + dy = cond_output + dd = cond_output - uncond_output + # Find parallel direction + parallel_direction = (dy * dd).sum() / (dy * dy).sum() * dy + orthogonal_direction = dd - parallel_direction + # Scale 
orthogonal component + orthogonal_std = orthogonal_direction.std() + orthogonal_scale = min(1, apg_config.orthogonal_threshold / orthogonal_std) + orthogonal_direction = orthogonal_direction * orthogonal_scale + model_output = dy + (guidance_scale - 1) * orthogonal_direction + else: + # Standard classifier-free guidance + model_output = uncond_output + guidance_scale * (cond_output - uncond_output) + else: + # If no guidance needed, just run the model normally + model_output = self.dit_model(latents, prompt_embeds, t_tensor) + + # Update latents + acc_latents = acc_latents + dt * model_output.to(device) + latents = acc_latents.clone() + + # 7. Decode latents + # These checks handle the case where mypy doesn't recognize these attributes + scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) if hasattr(self.vae, "config") else 0.18215 + shift_factor = getattr(self.vae.config, "shift_factor", 0) if hasattr(self.vae, "config") else 0 + + latents = latents / scaling_factor + shift_factor + + vae_dtype = self.vae.dtype if hasattr(self.vae, "dtype") else dtype + decoded_images = self.vae.decode(latents.to(vae_dtype)).sample if hasattr(self.vae, "decode") else latents + + # Offload all models + try: + self.maybe_free_model_hooks() + except AttributeError as e: + if "OptimizedModule" in str(e): + import warnings + warnings.warn( + "Encountered 'OptimizedModule' error when offloading models. " + "This issue might be fixed in the future by: " + "https://github.com/huggingface/diffusers/pull/10730" + ) + else: + raise + + # 8. 
Post-process images + images = (decoded_images / 2 + 0.5).clamp(0, 1) + # Convert to PIL Images + images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu() + pil_images = [Image.fromarray(img.permute(1, 2, 0).numpy()) for img in images] + + return FLitePipelineOutput( + images=pil_images, + ) \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/flux2/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/flux2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24ab4d91575ec21ba4825887edc64e01b153d90a --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/flux2/__init__.py @@ -0,0 +1,2 @@ +from .flux2_model import Flux2Model +from .flux2_klein_model import Flux2Klein4BModel, Flux2Klein9BModel diff --git a/ai-toolkit/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py b/ai-toolkit/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py new file mode 100644 index 0000000000000000000000000000000000000000..864bb41e567e623c74cf56de7e12423f3c2689a9 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py @@ -0,0 +1,91 @@ +from .flux2_model import Flux2Model +from transformers import Qwen3ForCausalLM, Qwen2Tokenizer +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype +from toolkit.config_modules import ModelConfig +from toolkit.memory_management.manager import MemoryManager +from toolkit.basic import flush +from .src.model import Klein9BParams, Klein4BParams + + +class Flux2KleinModel(Flux2Model): + flux2_klein_te_path: str = None + flux2_te_type: str = "qwen" # "mistral" or "qwen" + flux2_vae_path: str = "ai-toolkit/flux2_vae" + flux2_is_guidance_distilled: bool = False + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, + model_config, + dtype, + 
custom_pipeline, + noise_scheduler, + **kwargs, + ) + # use the new format on this new model by default + self.use_old_lokr_format = False + + def load_te(self): + if self.flux2_klein_te_path is None: + raise ValueError("flux2_klein_te_path must be set for Flux2KleinModel") + dtype = self.torch_dtype + self.print_and_status_update("Loading Qwen3") + + text_encoder: Qwen3ForCausalLM = Qwen3ForCausalLM.from_pretrained( + self.flux2_klein_te_path, + torch_dtype=dtype, + ) + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Qwen3") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + elif not self.model_config.low_vram: + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + + tokenizer = Qwen2Tokenizer.from_pretrained(self.flux2_klein_te_path) + return text_encoder, tokenizer + + +class Flux2Klein4BModel(Flux2KleinModel): + arch = "flux2_klein_4b" + flux2_klein_te_path: str = "Qwen/Qwen3-4B" + flux2_te_filename: str = "flux-2-klein-base-4b.safetensors" + + def get_flux2_params(self): + return Klein4BParams() + + def get_base_model_version(self): + return "flux2_klein_4b" + + +class Flux2Klein9BModel(Flux2KleinModel): + arch = "flux2_klein_9b" + flux2_klein_te_path: str = "Qwen/Qwen3-8B" + flux2_te_filename: str = "flux-2-klein-base-9b.safetensors" + + def get_flux2_params(self): + return Klein9BParams() + + def get_base_model_version(self): + return "flux2_klein_9b" diff --git a/ai-toolkit/extensions_built_in/diffusion_models/flux2/flux2_model.py b/ai-toolkit/extensions_built_in/diffusion_models/flux2/flux2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e22a1cd7eda325d031960fe1a5d1c2cbd8d70610 --- 
/dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/flux2/flux2_model.py @@ -0,0 +1,552 @@ +import math +import os +from typing import TYPE_CHECKING, List, Optional + +import huggingface_hub +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.memory_management.manager import MemoryManager +from toolkit.metadata import get_meta_for_safetensors +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype, quantize_model + +from transformers import AutoProcessor, Mistral3ForConditionalGeneration +from .src.model import Flux2, Flux2Params +from .src.pipeline import Flux2Pipeline +from .src.autoencoder import AutoEncoder, AutoEncoderParams, AutoEncoderSmallDecoderParams +from safetensors.torch import load_file, save_file +from PIL import Image +import torch.nn.functional as F + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +from .src.sampling import ( + batched_prc_img, + batched_prc_txt, + encode_image_refs, + scatter_ids, +) + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True, +} + +MISTRAL_PATH = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" +FLUX2_VAE_FILENAME = "ae.safetensors" +FLUX2_TRANSFORMER_FILENAME = "flux2-dev.safetensors" + +HF_TOKEN = os.getenv("HF_TOKEN", None) + + +class Flux2Model(BaseModel): + arch = "flux2" + flux2_te_type: str = "mistral" # "mistral" or "qwen" + flux2_vae_path: str = None + flux2_te_filename: str = 
def load_te(self):
    """Load the Mistral text encoder and its processor.

    The encoder is loaded from ``MISTRAL_PATH`` in ``self.torch_dtype`` and
    moved to ``self.device_torch``. When ``model_config.quantize_te`` is set
    it is then quantized and frozen. Layer offloading is attached when it is
    enabled with a positive text-encoder percentage.

    Returns:
        tuple: ``(text_encoder, processor)``; the processor is what the rest
        of this class stores as the "tokenizer".
    """
    dtype = self.torch_dtype
    self.print_and_status_update("Loading Mistral")

    text_encoder: Mistral3ForConditionalGeneration = (
        Mistral3ForConditionalGeneration.from_pretrained(
            MISTRAL_PATH,
            torch_dtype=dtype,
        )
    )
    text_encoder.to(self.device_torch, dtype=dtype)

    flush()

    if self.model_config.quantize_te:
        self.print_and_status_update("Quantizing Mistral")
        # Use the text-encoder quantization type (qtype_te), not the
        # transformer's `qtype`, so a transformer-only setting is never
        # applied to the text encoder. This matches the Klein loader.
        # NOTE(review): assumes ModelConfig always defines qtype_te - confirm.
        quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
        freeze(text_encoder)
        flush()

    if (
        self.model_config.layer_offloading
        and self.model_config.layer_offloading_text_encoder_percent > 0
    ):
        MemoryManager.attach(
            text_encoder,
            self.device_torch,
            offload_percent=self.model_config.layer_offloading_text_encoder_percent,
        )

    tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
    return text_encoder, tokenizer
self.print_and_status_update("Loading transformer") + with torch.device("meta"): + transformer = Flux2(self.get_flux2_params()) + + # use local path if provided + if os.path.exists(os.path.join(transformer_path, self.flux2_te_filename)): + transformer_path = os.path.join(transformer_path, self.flux2_te_filename) + + if not os.path.exists(transformer_path): + # assume it is from the hub + transformer_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=self.flux2_te_filename, + token=HF_TOKEN, + ) + + transformer_state_dict = load_file(transformer_path, device="cpu") + + # cast to dtype + for key in transformer_state_dict: + transformer_state_dict[key] = transformer_state_dict[key].to(dtype) + + transformer.load_state_dict(transformer_state_dict, assign=True) + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + # Avoid full-model peak VRAM allocation before quantization. + self.print_and_status_update("Keeping transformer on CPU for quantization") + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + else: + transformer.to(self.device_torch, dtype=dtype) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ) + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + + text_encoder, tokenizer = self.load_te() + + self.print_and_status_update("Loading VAE") + vae_path = self.model_config.vae_path + + if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)): + vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME) + + if vae_path is None: + vae_path = self.flux2_vae_path + + if vae_path is None or not os.path.exists(vae_path): + vae_filename = FLUX2_VAE_FILENAME + if 
vae_path is not None: + # see if it is a filename for huggingface hub + if len(vae_path.split("/")) == 3 and vae_path.endswith(".safetensors"): + vae_filename = vae_path.split("/")[-1] + vae_path = "/".join(vae_path.split("/")[:-1]) + p = vae_path if vae_path is not None else model_path + # assume it is from the hub + vae_path = huggingface_hub.hf_hub_download( + repo_id=p, + filename=vae_filename, + token=HF_TOKEN, + ) + + vae_state_dict = load_file(vae_path, device="cpu") + + autoencoder_params = AutoEncoderParams() + if vae_state_dict['decoder.up.0.block.0.conv1.bias'].shape[0] == 96: + # this is the small decoder version + autoencoder_params = AutoEncoderSmallDecoderParams() + + with torch.device("meta"): + vae = AutoEncoder(autoencoder_params) + + # cast to dtype + for key in vae_state_dict: + vae_state_dict[key] = vae_state_dict[key].to(dtype) + + vae.load_state_dict(vae_state_dict, assign=True) + + self.noise_scheduler = Flux2Model.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + pipe: Flux2Pipeline = Flux2Pipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + transformer=None, + text_encoder_type=self.flux2_te_type, + is_guidance_distilled=self.flux2_is_guidance_distilled, + ) + # for quantization, it works best to do these after making the pipe + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] + + flush() + # just to make sure everything is on the right device and dtype + if self.model_config.low_vram: + text_encoder[0].to("cpu") + else: + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + if self.model_config.low_vram: + pipe.transformer = pipe.transformer.to("cpu") + else: + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = 
def get_generation_pipeline(self):
    """Assemble a ``Flux2Pipeline`` for sampling from the loaded components.

    A new scheduler is created through ``Flux2Model.get_train_scheduler``.
    The text encoder, VAE and transformer are passed through
    ``unwrap_model`` before being handed to the pipeline, which is then
    moved to ``self.device_torch``.

    Returns:
        The result of calling ``.to(self.device_torch)`` on the new pipeline.
    """
    components = {
        "scheduler": Flux2Model.get_train_scheduler(),
        "text_encoder": unwrap_model(self.text_encoder[0]),
        "tokenizer": self.tokenizer[0],
        "vae": unwrap_model(self.vae),
        "transformer": unwrap_model(self.transformer),
        "text_encoder_type": self.flux2_te_type,
        "is_guidance_distilled": self.flux2_is_guidance_distilled,
    }
    return Flux2Pipeline(**components).to(self.device_torch)
prompt_embeds=conditional_embeds.text_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + control_img_list=control_img_list, + **extra, + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + guidance_embedding_scale: float, + batch: "DataLoaderBatchDTO" = None, + **kwargs, + ): + with torch.no_grad(): + txt, txt_ids = batched_prc_txt(text_embeddings.text_embeds) + packed_latents, img_ids = batched_prc_img(latent_model_input) + + # prepare image conditioning if any + img_cond_seq: torch.Tensor | None = None + img_cond_seq_ids: torch.Tensor | None = None + + # handle control images + batch_control_tensor_list = batch.control_tensor_list + if batch_control_tensor_list is None and batch.control_tensor is not None: + batch_control_tensor_list = [] + for b in range(latent_model_input.shape[0]): + batch_control_tensor_list.append(batch.control_tensor[b : b + 1]) + + if batch_control_tensor_list is not None: + batch_size, num_channels_latents, height, width = ( + latent_model_input.shape + ) + + control_image_max_res = 1024 * 1024 + if self.model_config.model_kwargs.get("match_target_res", False): + # use the current target size to set the control image res + control_image_res = ( + height + * self.pipeline.vae_scale_factor + * width + * self.pipeline.vae_scale_factor + ) + control_image_max_res = control_image_res + + if len(batch_control_tensor_list) != batch_size: + raise ValueError( + "Control tensor list length does not match batch size" + ) + for control_tensor_list in batch_control_tensor_list: + # control tensor list is a list of tensors for this batch item + controls = [] + # pack control + for control_img in control_tensor_list: + # control images are 0 - 1 scale, shape (1, ch, 
height, width) + control_img = control_img.to( + self.device_torch, dtype=self.torch_dtype + ) + # if it is only 3 dim, add batch dim + if len(control_img.shape) == 3: + control_img = control_img.unsqueeze(0) + + # resize to fit within max res while keeping aspect ratio + if self.model_config.model_kwargs.get( + "match_target_res", False + ): + ratio = control_img.shape[2] / control_img.shape[3] + c_height = math.sqrt(control_image_res * ratio) + c_width = c_height / ratio + + c_width = round(c_width / 32) * 32 + c_height = round(c_height / 32) * 32 + + control_img = F.interpolate( + control_img, size=(c_height, c_width), mode="bilinear" + ) + + # scale to -1 to 1 + control_img = control_img * 2 - 1 + controls.append(control_img) + + if self.vae.device == torch.device("cpu"): + self.vae.to(self.device_torch) + img_cond_seq_item, img_cond_seq_ids_item = encode_image_refs( + self.vae, controls, limit_pixels=control_image_max_res + ) + if img_cond_seq is None: + img_cond_seq = img_cond_seq_item + img_cond_seq_ids = img_cond_seq_ids_item + else: + img_cond_seq = torch.cat( + (img_cond_seq, img_cond_seq_item), dim=0 + ) + img_cond_seq_ids = torch.cat( + (img_cond_seq_ids, img_cond_seq_ids_item), dim=0 + ) + + img_input = packed_latents + img_input_ids = img_ids + + if img_cond_seq is not None: + assert img_cond_seq_ids is not None, ( + "You need to provide either both or neither of the sequence conditioning" + ) + img_input = torch.cat((img_input, img_cond_seq.to(img_input.device, img_input.dtype)), dim=1) + img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids.to(img_input_ids.device)), dim=1) + + guidance_vec = torch.full( + (img_input.shape[0],), + guidance_embedding_scale, + device=img_input.device, + dtype=img_input.dtype, + ) + + cast_dtype = self.model.dtype + + packed_noise_pred = self.transformer( + x=img_input.to(self.device_torch, cast_dtype), + x_ids=img_input_ids.to(self.device_torch), + timesteps=timestep.to(self.device_torch, cast_dtype) / 1000, + 
def convert_lora_weights_before_save(self, state_dict):
    """Rename LoRA keys from the in-memory naming to the saved naming.

    Every occurrence of ``"transformer."`` in a key is replaced with
    ``"diffusion_model."``. Values are passed through untouched and the
    input mapping is not modified.

    Args:
        state_dict: mapping of parameter name to value.

    Returns:
        dict: a new mapping with the renamed keys, in the same order.
    """
    return {
        name.replace("transformer.", "diffusion_model."): tensor
        for name, tensor in state_dict.items()
    }
def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
    """Stack a list of image tensors and encode them with the VAE.

    Args:
        image_list: image tensors; they are stacked along a new leading
            dimension, so they must all have the same shape.
        device: target device; defaults to ``self.vae_device_torch``.
        dtype: target dtype; defaults to ``self.vae_torch_dtype``.

    Returns:
        Whatever ``self.vae.encode`` returns for the stacked batch.
    """
    target_device = self.vae_device_torch if device is None else device
    target_dtype = self.vae_torch_dtype if dtype is None else dtype

    # The VAE may have been parked on the CPU; bring it to the target device.
    if self.vae.device == torch.device("cpu"):
        self.vae.to(target_device)

    moved = []
    for image in image_list:
        moved.append(image.to(target_device, dtype=target_dtype))
    batch = torch.stack(moved).to(target_device, dtype=target_dtype)

    return self.vae.encode(batch)
field(default_factory=lambda: [1, 2, 4, 4]) + num_res_blocks: int = 2 + z_channels: int = 32 + +@dataclass +class AutoEncoderSmallDecoderParams: + resolution: int = 256 + in_channels: int = 3 + ch: int = 128 + ch_encoder: int = 96 + out_ch: int = 3 + ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4]) + num_res_blocks: int = 2 + z_channels: int = 32 + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = nn.GroupNorm( + num_groups=32, num_channels=out_channels, eps=1e-6, affine=True + ) 
+ self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1) + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + 
block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = ckpt.checkpoint(self.down[i_level].block[i_block], hs[-1]) + if len(self.down[i_level].attn) > 0: + h = ckpt.checkpoint(self.down[i_level].attn[i_block], h) + else: + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hs.append(ckpt.checkpoint(self.down[i_level].downsample, hs[-1])) + else: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = ckpt.checkpoint(self.mid.block_1, h) + h = ckpt.checkpoint(self.mid.attn_1, h) + h = ckpt.checkpoint(self.mid.block_2, h) + else: + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = 
class Decoder(nn.Module):
    """Convolutional decoder that maps a latent ``z`` to an image tensor.

    Structure: 1x1 ``post_quant_conv`` -> ``conv_in`` -> middle
    (ResnetBlock, AttnBlock, ResnetBlock) -> ``len(ch_mult)`` upsampling
    levels of ``num_res_blocks + 1`` ResnetBlocks each (with an ``Upsample``
    after every level except level 0) -> GroupNorm, swish, ``conv_out``.

    Args:
        ch: base channel count; level ``i`` uses ``ch * ch_mult[i]`` channels.
        out_ch: number of output channels.
        ch_mult: per-level channel multipliers.
        num_res_blocks: stored on the module; each level builds
            ``num_res_blocks + 1`` blocks.
        in_channels: stored on the module; not used to build any layer here.
        resolution: only used to compute ``z_shape``.
        z_channels: number of latent channels.

    Submodule attribute names are unchanged so existing checkpoints
    (keys such as ``up.0.block.0.conv1.bias``) still load.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # channel count and spatial size at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # always empty here; kept for key layout
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
            # Levels are built from the deepest one first; prepending keeps
            # self.up[i] aligned with ch_mult[i].
            self.up.insert(0, up)

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        """Turn on activation checkpointing for subsequent forward passes."""
        self.gradient_checkpointing = True

    def _run(self, module: nn.Module, h: Tensor) -> Tensor:
        """Apply ``module`` to ``h``, checkpointed when enabled and grad is on."""
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # use_reentrant is passed explicitly, as Flux2.forward does;
            # omitting it triggers a PyTorch warning.
            return ckpt.checkpoint(module, h, use_reentrant=False)
        return module(h)

    def forward(self, z: Tensor) -> Tensor:
        """Decode latent ``z`` into an image tensor."""
        z = self.post_quant_conv(z)

        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self._run(self.mid.block_1, h)
        h = self._run(self.mid.attn_1, h)
        h = self._run(self.mid.block_2, h)

        # cast to proper dtype
        h = h.to(upscale_dtype)

        # upsampling, deepest level first
        for i_level in reversed(range(self.num_resolutions)):
            level = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = self._run(level.block[i_block], h)
                if len(level.attn) > 0:
                    h = self._run(level.attn[i_block], h)
            if i_level != 0:
                h = self._run(level.upsample, h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
self.bn_momentum = 0.1 + self.ps = [2, 2] + self.bn = torch.nn.BatchNorm2d( + math.prod(self.ps) * params.z_channels, + eps=self.bn_eps, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + ) + self._gradient_checkpointing = False + + @property + def gradient_checkpointing(self): + return self._gradient_checkpointing + + @gradient_checkpointing.setter + def gradient_checkpointing(self, value: bool): + self._gradient_checkpointing = value + self.encoder.gradient_checkpointing = value + self.decoder.gradient_checkpointing = value + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + self.encoder.enable_gradient_checkpointing() + self.decoder.enable_gradient_checkpointing() + + def normalize(self, z): + self.bn.eval() + return self.bn(z) + + def inv_normalize(self, z): + self.bn.eval() + s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps) + m = self.bn.running_mean.view(1, -1, 1, 1) + return z * s + m + + def encode(self, x: Tensor) -> Tensor: + moments = self.encoder(x) + mean = torch.chunk(moments, 2, dim=1)[0] + + z = rearrange( + mean, + "... c (i pi) (j pj) -> ... (c pi pj) i j", + pi=self.ps[0], + pj=self.ps[1], + ) + z = self.normalize(z) + return z + + def decode(self, z: Tensor) -> Tensor: + z = self.inv_normalize(z) + z = rearrange( + z, + "... (c pi pj) i j -> ... 
@dataclass
class Flux2Params:
    """Architecture hyper-parameters consumed by ``Flux2.__init__``.

    ``hidden_size`` must be divisible by ``num_heads`` and ``sum(axes_dim)``
    must equal ``hidden_size // num_heads``; ``Flux2.__init__`` raises
    ``ValueError`` otherwise. The defaults are what
    ``Flux2Model.get_flux2_params`` returns.
    """

    in_channels: int = 128
    context_in_dim: int = 15360
    hidden_size: int = 6144
    num_heads: int = 48
    depth: int = 8  # number of double-stream blocks
    depth_single_blocks: int = 48  # number of single-stream blocks
    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
    theta: int = 2000
    mlp_ratio: float = 3.0
    use_guidance_embed: bool = True


@dataclass
class Klein9BParams(Flux2Params):
    """Parameters for the Klein 9B variant.

    Subclasses ``Flux2Params`` so that it satisfies the
    ``params: Flux2Params`` annotation on ``Flux2.__init__``. Field names,
    order and defaults are the same as in the previous standalone dataclass.
    """

    context_in_dim: int = 12288
    hidden_size: int = 4096
    num_heads: int = 32
    depth: int = 8
    depth_single_blocks: int = 24
    use_guidance_embed: bool = False


@dataclass
class Klein4BParams(Flux2Params):
    """Parameters for the Klein 4B variant.

    Subclasses ``Flux2Params`` so that it satisfies the
    ``params: Flux2Params`` annotation on ``Flux2.__init__``. Field names,
    order and defaults are the same as in the previous standalone dataclass.
    """

    context_in_dim: int = 7680
    hidden_size: int = 3072
    num_heads: int = 24
    depth: int = 5
    depth_single_blocks: int = 20
    use_guidance_embed: bool = False
+ raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False) + self.time_in = MLPEmbedder( + in_dim=256, hidden_dim=self.hidden_size, disable_bias=True + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False) + + self.use_guidance_embed = params.use_guidance_embed + if self.use_guidance_embed: + self.guidance_in = MLPEmbedder( + in_dim=256, hidden_dim=self.hidden_size, disable_bias=True + ) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + self.double_stream_modulation_img = Modulation( + self.hidden_size, + double=True, + disable_bias=True, + ) + self.double_stream_modulation_txt = Modulation( + self.hidden_size, + double=True, + disable_bias=True, + ) + self.single_stream_modulation = Modulation( + self.hidden_size, double=False, disable_bias=True + ) + + self.final_layer = LastLayer( + self.hidden_size, + self.out_channels, + ) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def forward( + self, + x: Tensor, + x_ids: Tensor, + timesteps: 
Tensor, + ctx: Tensor, + ctx_ids: Tensor, + guidance: Tensor | None, + ): + num_txt_tokens = ctx.shape[1] + + timestep_emb = timestep_embedding(timesteps, 256) + vec = self.time_in(timestep_emb) + if self.use_guidance_embed: + guidance_emb = timestep_embedding(guidance, 256) + vec = vec + self.guidance_in(guidance_emb) + + double_block_mod_img = self.double_stream_modulation_img(vec) + double_block_mod_txt = self.double_stream_modulation_txt(vec) + single_block_mod, _ = self.single_stream_modulation(vec) + + img = self.img_in(x) + txt = self.txt_in(ctx) + + pe_x = self.pe_embedder(x_ids) + pe_ctx = self.pe_embedder(ctx_ids) + + for block in self.double_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img, txt = ckpt.checkpoint( + block, + img, + txt, + pe_x, + pe_ctx, + double_block_mod_img, + double_block_mod_txt, + use_reentrant=False, + ) + else: + img, txt = block( + img, + txt, + pe_x, + pe_ctx, + double_block_mod_img, + double_block_mod_txt, + ) + + img = torch.cat((txt, img), dim=1) + pe = torch.cat((pe_ctx, pe_x), dim=2) + + for i, block in enumerate(self.single_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + img = ckpt.checkpoint( + block, + img, + pe, + single_block_mod, + use_reentrant=False, + ) + else: + img = block( + img, + pe, + single_block_mod, + ) + + img = img[:, num_txt_tokens:, ...] 
+ + img = self.final_layer(img, vec) + return img + + +class SelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=False) + + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim, bias=False) + + +class SiLUActivation(nn.Module): + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool, disable_bias: bool = False): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias) + + def forward(self, vec: torch.Tensor): + out = self.lin(nn.functional.silu(vec)) + if out.ndim == 2: + out = out[:, None, :] + out = out.chunk(self.multiplier, dim=-1) + return out[:3], out[3:] if self.is_double else None + + +class LastLayer(nn.Module): + def __init__( + self, + hidden_size: int, + out_channels: int, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=False) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False) + ) + + def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: + mod = self.adaLN_modulation(vec) + shift, scale = mod.chunk(2, dim=-1) + if shift.ndim == 2: + shift = shift[:, None, :] + scale = scale[:, None, :] + x = (1 + scale) * self.norm_final(x) + shift + x = self.linear(x) + return x + + +class SingleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + ): + super().__init__() + + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // 
num_heads + self.scale = head_dim**-0.5 + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp_mult_factor = 2 + + self.linear1 = nn.Linear( + hidden_size, + hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, + bias=False, + ) + + self.linear2 = nn.Linear( + hidden_size + self.mlp_hidden_dim, hidden_size, bias=False + ) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = SiLUActivation() + + def forward( + self, + x: Tensor, + pe: Tensor, + mod: tuple[Tensor, Tensor], + ) -> Tensor: + mod_shift, mod_scale, mod_gate = mod + x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift + + qkv, mlp = torch.split( + self.linear1(x_mod), + [3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor], + dim=-1, + ) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + attn = attention(q, k, v, pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod_gate * output + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + ): + super().__init__() + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + assert hidden_size % num_heads == 0, ( + f"{hidden_size=} must be divisible by {num_heads=}" + ) + + self.hidden_size = hidden_size + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.mlp_mult_factor = 2 + + self.img_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False), + SiLUActivation(), + nn.Linear(mlp_hidden_dim, hidden_size, bias=False), + ) + + 
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear( + hidden_size, + mlp_hidden_dim * self.mlp_mult_factor, + bias=False, + ), + SiLUActivation(), + nn.Linear(mlp_hidden_dim, hidden_size, bias=False), + ) + + def forward( + self, + img: Tensor, + txt: Tensor, + pe: Tensor, + pe_ctx: Tensor, + mod_img: tuple[Tensor, Tensor], + mod_txt: tuple[Tensor, Tensor], + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = mod_img + txt_mod1, txt_mod2 = mod_txt + + img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1 + img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2 + txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1 + txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2 + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift + + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift + + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + pe = torch.cat((pe_ctx, pe), dim=2) + attn = attention(q, k, v, pe) + txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :] + + # calculate the img blocks + img = img + img_mod1_gate * self.img_attn.proj(img_attn) + img = img + 
img_mod2_gate * self.img_mlp( + (1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift + ) + + # calculate the txt blocks + txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2_gate * self.txt_mlp( + (1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift + ) + return img, txt + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + emb = torch.cat( + [ + rope(ids[..., i], self.axes_dim[i], self.theta) + for i in range(len(self.axes_dim)) + ], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, device=t.device, dtype=torch.float32) + / half + ) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack( + [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1 + ) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + 
freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/flux2/src/pipeline.py b/ai-toolkit/extensions_built_in/diffusion_models/flux2/src/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..8e64639349524e9f14b26de3adc1a08f6003525b --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/flux2/src/pipeline.py @@ -0,0 +1,452 @@ +from typing import List, Optional, Union + +import numpy as np +import torch +import PIL.Image +from dataclasses import dataclass +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput +from .autoencoder import AutoEncoder +from .model import Flux2 +from einops import rearrange +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + +from .sampling import ( + get_schedule, + batched_prc_img, + batched_prc_txt, + encode_image_refs, + scatter_ids, +) + + +@dataclass +class Flux2ImagePipelineOutput(BaseOutput): + images: Union[List[PIL.Image.Image], np.ndarray] + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object +attribution and actions without speculation.""" +OUTPUT_LAYERS_MISTRAL = [10, 20, 30] +OUTPUT_LAYERS_QWEN3 = [9, 18, 27] +MAX_LENGTH = 512 + + +class Flux2Pipeline(DiffusionPipeline): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoEncoder, + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + transformer: Flux2, + text_encoder_type: str = "mistral", # "mistral" or "qwen" + is_guidance_distilled: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 16 # 8x plus 2x pixel shuffle + self.num_channels_latents = 128 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = 64 + self.text_encoder_type = text_encoder_type + self.is_guidance_distilled = is_guidance_distilled + + def format_input( + self, + txt: list[str], + ) -> list[list[dict]]: + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. 
+ cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": SYSTEM_MESSAGE}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + def _get_mistral_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 512, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if not isinstance(prompt, list): + prompt = [prompt] + + # Format input messages + messages_batch = self.format_input(txt=prompt) + + # Process all messages at once + # with image processing a too short max length can throw an error in here. + try: + inputs = self.tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + except ValueError as e: + print( + f"Error processing input: {e}, your max length is probably too short, when you have images in the input." 
+ ) + raise e + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + out = torch.stack( + [output.hidden_states[k] for k in OUTPUT_LAYERS_MISTRAL], dim=1 + ) + prompt_embeds = rearrange(out, "b c l d -> b l (c d)") + + # they don't return attention mask, so we create it here + return prompt_embeds, None + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 512, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if not isinstance(prompt, list): + prompt = [prompt] + + all_input_ids = [] + all_attention_masks = [] + + for p in prompt: + messages = [{"role": "user", "content": p}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + + model_inputs = self.tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(model_inputs["input_ids"]) + all_attention_masks.append(model_inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + output = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS_QWEN3], dim=1) + prompt_embeds = rearrange(out, "b c l d -> b l (c d)") + + # they dont use attention mask + return prompt_embeds, None + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + 
num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + if self.text_encoder_type == "mistral": + prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) + elif self.text_encoder_type == "qwen": + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) + else: + raise ValueError( + f"Unsupported text_encoder_type: {self.text_encoder_type}" + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + return prompt_embeds, prompt_embeds_mask + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + max_sequence_length: int = 512, + control_img_list: Optional[List[PIL.Image.Image]] = None, + ): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + do_guidance = ( + guidance_scale is not None + and guidance_scale > 1.0 + and not self.is_guidance_distilled + ) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. 
Encode the prompt + + prompt_embeds, _ = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + txt, txt_ids = batched_prc_txt(prompt_embeds) + neg_txt, neg_txt_ids = None, None + + if do_guidance: + negative_prompt_embeds, _ = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + neg_txt, neg_txt_ids = batched_prc_txt(negative_prompt_embeds) + + # 4. Prepare latent variables\ + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + self.num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + packed_latents, img_ids = batched_prc_img(latents) + + timesteps = get_schedule(num_inference_steps, packed_latents.shape[1]) + + self._num_timesteps = len(timesteps) + + guidance_vec = torch.full( + (packed_latents.shape[0],), + guidance_scale, + device=packed_latents.device, + dtype=packed_latents.dtype, + ) + + if control_img_list is not None and len(control_img_list) > 0: + img_cond_seq, img_cond_seq_ids = encode_image_refs( + self.vae, control_img_list + ) + else: + img_cond_seq, img_cond_seq_ids = None, None + + # 6. 
Denoising loop + i = 0 + with self.progress_bar(total=num_inference_steps) as progress_bar: + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + if self.interrupt: + continue + t_vec = torch.full( + (packed_latents.shape[0],), + t_curr, + dtype=packed_latents.dtype, + device=packed_latents.device, + ) + + self._current_timestep = t_curr + img_input = packed_latents + img_input_ids = img_ids + + if img_cond_seq is not None: + assert img_cond_seq_ids is not None, ( + "You need to provide either both or neither of the sequence conditioning" + ) + img_input = torch.cat((img_input, img_cond_seq), dim=1) + img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) + + pred = self.transformer( + x=img_input, + x_ids=img_input_ids, + timesteps=t_vec, + ctx=txt, + ctx_ids=txt_ids, + guidance=guidance_vec, + ) + + if do_guidance: + pred_uncond = self.transformer( + x=img_input, + x_ids=img_input_ids, + timesteps=t_vec, + ctx=neg_txt, + ctx_ids=neg_txt_ids, + guidance=guidance_vec, + ) + pred = pred_uncond + guidance_scale * (pred - pred_uncond) + + if img_cond_seq is not None: + pred = pred[:, : packed_latents.shape[1]] + + packed_latents = packed_latents + (t_prev - t_curr) * pred + i += 1 + progress_bar.update(1) + + self._current_timestep = None + + # 7. 
Post-processing + latents = torch.cat(scatter_ids(packed_latents, img_ids)).squeeze(2) + + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + image = self.vae.decode(latents).float() + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2ImagePipelineOutput(images=image) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/flux2/src/sampling.py b/ai-toolkit/extensions_built_in/diffusion_models/flux2/src/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..b02af2d67c0c6eb8dd3be31d36cb6f930168e55b --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/flux2/src/sampling.py @@ -0,0 +1,365 @@ +import math +from typing import Callable, Union + +import torch +from einops import rearrange +from PIL import Image +from torch import Tensor + +from .model import Flux2 +import torchvision + + +def compress_time(t_ids: Tensor) -> Tensor: + assert t_ids.ndim == 1 + t_ids_max = torch.max(t_ids) + t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype) + t_unique_sorted_ids = torch.unique(t_ids, sorted=True) + t_remap[t_unique_sorted_ids] = torch.arange( + len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype + ) + t_ids_compressed = t_remap[t_ids] + return t_ids_compressed + + +def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + t_coords = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + t_ids = pos[:, 0].to(torch.int64) + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + t_ids_cmpr = compress_time(t_ids) + + t = torch.max(t_ids_cmpr) + 1 + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids + + out = torch.zeros((t * h * w, ch), 
device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w)) + t_coords.append(torch.unique(t_ids, sorted=True)) + return x_list + + +def encode_image_refs( + ae, + img_ctx: Union[list[Image.Image], list[torch.Tensor]], + scale=10, + limit_pixels=1024**2, +): + if not img_ctx: + return None, None + + img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels) + if not isinstance(img_ctx_prep, list): + img_ctx_prep = [img_ctx_prep] + + # Encode each reference image + encoded_refs = [] + for img in img_ctx_prep: + if img.ndim == 3: + img = img.unsqueeze(0) + encoded = ae.encode(img.to(ae.device, ae.dtype))[0] + encoded_refs.append(encoded) + + # Create time offsets for each reference + t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))] + t_off = [t.view(-1) for t in t_off] + + # Process with position IDs + ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off) + + # Concatenate all references along sequence dimension + ref_tokens = torch.cat(ref_tokens, dim=0) # (total_ref_tokens, C) + ref_ids = torch.cat(ref_ids, dim=0) # (total_ref_tokens, 4) + + # Add batch dimension + ref_tokens = ref_tokens.unsqueeze(0) # (1, total_ref_tokens, C) + ref_ids = ref_ids.unsqueeze(0) # (1, total_ref_tokens, 4) + + return ref_tokens.to(torch.bfloat16), ref_ids + + +def prc_txt( + x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None +) -> tuple[Tensor, Tensor]: + assert l_coord is None, "l_coord not supported for txts" + + _l, _ = x.shape # noqa: F841 + + coords = { + "t": torch.arange(1) if t_coord is None else t_coord, + "h": torch.arange(1), # dummy dimension + "w": torch.arange(1), # dummy dimension + "l": torch.arange(_l), + } + x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"]) + return x, x_ids.to(x.device) + + +def batched_wrapper(fn): + def batched_prc( + x: Tensor, t_coord: 
Tensor | None = None, l_coord: Tensor | None = None + ) -> tuple[Tensor, Tensor]: + results = [] + for i in range(len(x)): + results.append( + fn( + x[i], + t_coord[i] if t_coord is not None else None, + l_coord[i] if l_coord is not None else None, + ) + ) + x, x_ids = zip(*results) + return torch.stack(x), torch.stack(x_ids) + + return batched_prc + + +def listed_wrapper(fn): + def listed_prc( + x: list[Tensor], + t_coord: list[Tensor] | None = None, + l_coord: list[Tensor] | None = None, + ) -> tuple[list[Tensor], list[Tensor]]: + results = [] + for i in range(len(x)): + results.append( + fn( + x[i], + t_coord[i] if t_coord is not None else None, + l_coord[i] if l_coord is not None else None, + ) + ) + x, x_ids = zip(*results) + return list(x), list(x_ids) + + return listed_prc + + +def prc_img( + x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None +) -> tuple[Tensor, Tensor]: + c, h, w = x.shape # noqa: F841 + x_coords = { + "t": torch.arange(1) if t_coord is None else t_coord, + "h": torch.arange(h), + "w": torch.arange(w), + "l": torch.arange(1) if l_coord is None else l_coord, + } + x_ids = torch.cartesian_prod( + x_coords["t"], x_coords["h"], x_coords["w"], x_coords["l"] + ) + x = rearrange(x, "c h w -> (h w) c") + return x, x_ids.to(x.device) + + +listed_prc_img = listed_wrapper(prc_img) +batched_prc_img = batched_wrapper(prc_img) +batched_prc_txt = batched_wrapper(prc_txt) + + +def center_crop_to_multiple_of_x( + img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], x: int +) -> Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor]: + if isinstance(img, list): + return [center_crop_to_multiple_of_x(_img, x) for _img in img] # type: ignore + + if isinstance(img, torch.Tensor): + h, w = img.shape[-2], img.shape[-1] + else: + w, h = img.size + new_w = (w // x) * x + new_h = (h // x) * x + + left = (w - new_w) // 2 + top = (h - new_h) // 2 + right = left + new_w + bottom = top + new_h + + if isinstance(img, 
torch.Tensor): + return img[..., top:bottom, left:right] + resized = img.crop((left, top, right, bottom)) + return resized + + +def cap_pixels( + img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], k +): + if isinstance(img, list): + return [cap_pixels(_img, k) for _img in img] + if isinstance(img, torch.Tensor): + h, w = img.shape[-2], img.shape[-1] + else: + w, h = img.size + pixel_count = w * h + + if pixel_count <= k: + return img + + # Scaling factor to reduce total pixels below K + scale = math.sqrt(k / pixel_count) + new_w = int(w * scale) + new_h = int(h * scale) + + if isinstance(img, torch.Tensor): + did_expand = False + if img.ndim == 3: + img = img.unsqueeze(0) + did_expand = True + img = torch.nn.functional.interpolate( + img, + size=(new_h, new_w), + mode="bicubic", + align_corners=False, + ) + if did_expand: + img = img.squeeze(0) + return img + return img.resize((new_w, new_h), Image.Resampling.LANCZOS) + + +def cap_min_pixels( + img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], + max_ar=8, + min_sidelength=64, +): + if isinstance(img, list): + return [ + cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength) + for _img in img + ] + if isinstance(img, torch.Tensor): + h, w = img.shape[-2], img.shape[-1] + else: + w, h = img.size + if w < min_sidelength or h < min_sidelength: + raise ValueError( + f"Skipping due to minimal sidelength underschritten h {h} w {w}" + ) + if w / h > max_ar or h / w > max_ar: + raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}") + return img + + +def to_rgb( + img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], +) -> Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor]: + if isinstance(img, list): + return [ + to_rgb( + _img, + ) + for _img in img + ] + if isinstance(img, torch.Tensor): + return img # assume already in tensor format + return img.convert("RGB") + + +def default_images_prep( + x: Image.Image | 
list[Image.Image] | torch.Tensor | list[torch.Tensor], +) -> torch.Tensor | list[torch.Tensor]: + if isinstance(x, list): + return [default_images_prep(e) for e in x] # type: ignore + if isinstance(x, torch.Tensor): + return x # assume already in tensor format + x_tensor = torchvision.transforms.ToTensor()(x) + return 2 * x_tensor - 1 + + +def default_prep( + img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], + limit_pixels: int, + ensure_multiple: int = 16, +) -> torch.Tensor | list[torch.Tensor]: + # if passing a tensor, assume it is -1 to 1 already + img_rgb = to_rgb(img) + img_min = cap_min_pixels(img_rgb) # type: ignore + img_cap = cap_pixels(img_min, limit_pixels) # type: ignore + img_crop = center_crop_to_multiple_of_x(img_cap, ensure_multiple) # type: ignore + img_tensor = default_images_prep(img_crop) + return img_tensor + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux2, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float, + # extra img tokens (sequence-wise) + img_cond_seq: Tensor | None = None, + img_cond_seq_ids: 
Tensor | None = None, +): + guidance_vec = torch.full( + (img.shape[0],), guidance, device=img.device, dtype=img.dtype + ) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + img_input = img + img_input_ids = img_ids + if img_cond_seq is not None: + assert img_cond_seq_ids is not None, ( + "You need to provide either both or neither of the sequence conditioning" + ) + img_input = torch.cat((img_input, img_cond_seq), dim=1) + img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) + pred = model( + x=img_input, + x_ids=img_input_ids, + timesteps=t_vec, + ctx=txt, + ctx_ids=txt_ids, + guidance=guidance_vec, + ) + if img_input_ids is not None: + pred = pred[:, : img.shape[1]] + + img = img + (t_prev - t_curr) * pred + + return img diff --git a/ai-toolkit/extensions_built_in/diffusion_models/flux_kontext/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/flux_kontext/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..081c949addccddef8874d60d5a0675582c422453 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/flux_kontext/__init__.py @@ -0,0 +1 @@ +from .flux_kontext import FluxKontextModel diff --git a/ai-toolkit/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py b/ai-toolkit/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py new file mode 100644 index 0000000000000000000000000000000000000000..e9eeee57ddddba5b40754d17cfb0e000830a6242 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py @@ -0,0 +1,420 @@ +import os +from typing import TYPE_CHECKING, List + +import torch +import torchvision +import yaml +from toolkit import train_tools +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from diffusers import FluxTransformer2DModel, AutoencoderKL, 
# Keyword arguments for the training scheduler built in get_train_scheduler().
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": 0.5,
    "max_image_seq_len": 4096,
    "max_shift": 1.15,
    "num_train_timesteps": 1000,
    "shift": 3.0,
    "use_dynamic_shifting": True
}


class FluxKontextModel(BaseModel):
    """Flux Kontext model wrapper: loading, sampling, prediction and saving.

    The control ("kontext") image is carried as extra latent channels by
    ``condition_noisy_latents`` and turned into extra sequence tokens inside
    ``get_noise_prediction``.
    """
    arch = "flux_kontext"

    def __init__(
            self,
            device,
            model_config: ModelConfig,
            dtype='bf16',
            custom_pipeline=None,
            noise_scheduler=None,
            **kwargs
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['FluxTransformer2DModel']

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        """Return a new scheduler built from the module-level ``scheduler_config``."""
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        """Return 16, the multiple used for bucket sizes."""
        return 16

    def load_model(self):
        """Load transformer, text encoders, tokenizers and VAE and build the pipeline.

        Stores the results on ``self.vae``, ``self.text_encoder`` (list),
        ``self.tokenizer`` (list), ``self.model`` and ``self.pipeline``.
        """
        dtype = self.torch_dtype
        self.print_and_status_update("Loading Flux Kontext model")
        # will be updated if we detect an existing checkpoint in training folder
        model_path = self.model_config.name_or_path
        # this is the original path put in the model directory
        # it is here because for finetuning we only save the transformer usually
        # so we need this for the VAE, te, etc
        base_model_path = self.model_config.extras_name_or_path

        transformer_path = model_path
        transformer_subfolder = 'transformer'
        if os.path.exists(transformer_path):
            # local checkpoint: point straight at its transformer folder
            transformer_subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path

        self.print_and_status_update("Loading transformer")
        transformer = FluxTransformer2DModel.from_pretrained(
            transformer_path,
            subfolder=transformer_subfolder,
            torch_dtype=dtype
        )
        transformer.to(self.quantize_device, dtype=dtype)

        if self.model_config.quantize:
            # patch the state dict method
            patch_dequantization_on_save(transformer)
            quantization_type = get_qtype(self.model_config.qtype)
            self.print_and_status_update("Quantizing transformer")
            quantize(transformer, weights=quantization_type,
                     **self.model_config.quantize_kwargs)
            freeze(transformer)
            transformer.to(self.device_torch)
        else:
            transformer.to(self.device_torch, dtype=dtype)

        flush()

        self.print_and_status_update("Loading T5")
        tokenizer_2 = T5TokenizerFast.from_pretrained(
            base_model_path, subfolder="tokenizer_2", torch_dtype=dtype
        )
        text_encoder_2 = T5EncoderModel.from_pretrained(
            base_model_path, subfolder="text_encoder_2", torch_dtype=dtype
        )
        text_encoder_2.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing T5")
            quantize(text_encoder_2, weights=get_qtype(
                self.model_config.qtype))
            freeze(text_encoder_2)
            flush()

        self.print_and_status_update("Loading CLIP")
        text_encoder = CLIPTextModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=dtype)
        tokenizer = CLIPTokenizer.from_pretrained(
            base_model_path, subfolder="tokenizer", torch_dtype=dtype)
        text_encoder.to(self.device_torch, dtype=dtype)

        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKL.from_pretrained(
            base_model_path, subfolder="vae", torch_dtype=dtype)

        self.noise_scheduler = FluxKontextModel.get_train_scheduler()

        self.print_and_status_update("Making pipe")

        pipe: FluxKontextPipeline = FluxKontextPipeline(
            scheduler=self.noise_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=None,
            tokenizer_2=tokenizer_2,
            vae=vae,
            transformer=None,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.text_encoder_2 = text_encoder_2
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
        tokenizer = [pipe.tokenizer, pipe.tokenizer_2]

        pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        text_encoder[1].to(self.device_torch)
        text_encoder[1].requires_grad_(False)
        text_encoder[1].eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        """Build a sampling pipeline from the unwrapped components with a new scheduler."""
        scheduler = FluxKontextModel.get_train_scheduler()

        pipeline: FluxKontextPipeline = FluxKontextPipeline(
            scheduler=scheduler,
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            text_encoder_2=unwrap_model(self.text_encoder[1]),
            tokenizer_2=self.tokenizer[1],
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer)
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: FluxKontextPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Sample one image conditioned on ``gen_config.ctrl_img``.

        ``gen_config.width`` and ``gen_config.height`` are rounded down to
        multiples of 16 in place, and the control image is resized to that size.
        ``unconditional_embeds`` is accepted but not used here.

        Raises:
            ValueError: If ``gen_config.ctrl_img`` is None.
        """
        if gen_config.ctrl_img is None:
            raise ValueError(
                "Control image is required for Flux Kontext model generation."
            )

        # Snap the target size first so the control image is resized once,
        # directly to the size that is actually generated.
        gen_config.width = int(gen_config.width // 16 * 16)
        gen_config.height = int(gen_config.height // 16 * 16)

        # convert() returns a new image, so the source file can be closed here
        with Image.open(gen_config.ctrl_img) as source_img:
            control_img = source_img.convert("RGB")
        # resize to width and height
        if control_img.size != (gen_config.width, gen_config.height):
            control_img = control_img.resize(
                (gen_config.width, gen_config.height), Image.BILINEAR
            )

        img = pipeline(
            image=control_img,
            prompt_embeds=conditional_embeds.text_embeds,
            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            max_area=gen_config.height * gen_config.width,
            _auto_resize=False,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        guidance_embedding_scale: float,
        bypass_guidance_embedding: bool,
        **kwargs
    ):
        """Run the transformer and return the prediction in unpacked latent layout.

        A 32-channel input is treated as [latent, control] stacked on the
        channel axis; the control half is appended as extra sequence tokens and
        its prediction is dropped from the output.

        Args:
            latent_model_input: 4-D latent tensor (unpacked as bs, c, h, w).
            timestep: Timesteps on the 0-1000 scale; divided by 1000 for the model.
            text_embeddings: Provides ``text_embeds`` and ``pooled_embeds``.
            guidance_embedding_scale: Scalar or list used when the model has
                guidance embeds.
            bypass_guidance_embedding: When true, the guidance embedding is
                bypassed for this call and restored afterwards, even if the
                forward pass raises.
            **kwargs: Forwarded to the transformer call.
        """
        # NOTE(review): the extent of this no_grad block is reconstructed so
        # that the transformer call below stays differentiable - confirm
        # against the source file.
        with torch.no_grad():
            bs, c, h, w = latent_model_input.shape
            # if we have a control on the channel dimension, put it on the batch for packing
            has_control = False
            if latent_model_input.shape[1] == 32:
                # chunk it and stack it on batch dimension
                # dont update batch size for img_its
                lat, control = torch.chunk(latent_model_input, 2, dim=1)
                latent_model_input = torch.cat([lat, control], dim=0)
                has_control = True

            latent_model_input_packed = rearrange(
                latent_model_input,
                "b c (h ph) (w pw) -> b (h w) (c ph pw)",
                ph=2,
                pw=2
            )

            img_ids = torch.zeros(h // 2, w // 2, 3)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c",
                             b=bs).to(self.device_torch)

            # handle control image ids
            if has_control:
                # control tokens reuse the spatial ids, with channel 0 set to 1
                ctrl_ids = img_ids.clone()
                ctrl_ids[..., 0] = 1
                img_ids = torch.cat([img_ids, ctrl_ids], dim=1)

            txt_ids = torch.zeros(
                bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)

            # handle guidance
            if self.unet_unwrapped.config.guidance_embeds:
                if isinstance(guidance_embedding_scale, list):
                    guidance = torch.tensor(
                        guidance_embedding_scale, device=self.device_torch)
                else:
                    guidance = torch.tensor(
                        [guidance_embedding_scale], device=self.device_torch)
                    # Expand guidance to match original batch_size
                    guidance = guidance.expand(bs)
            else:
                guidance = None

        if bypass_guidance_embedding:
            bypass_flux_guidance(self.unet)

        # try/finally so a failing forward pass cannot leave the guidance
        # embedding bypassed for later calls.
        try:
            cast_dtype = self.unet.dtype
            # changes from orig implementation
            if txt_ids.ndim == 3:
                txt_ids = txt_ids[0]
            if img_ids.ndim == 3:
                img_ids = img_ids[0]

            latent_size = latent_model_input_packed.shape[1]
            # move the kontext channels. We have them on batch dimension to here,
            # but need to put them on the latent dimension
            if has_control:
                latent, control = torch.chunk(latent_model_input_packed, 2, dim=0)
                latent_model_input_packed = torch.cat(
                    [latent, control], dim=1
                )
                latent_size = latent.shape[1]

            noise_pred = self.unet(
                hidden_states=latent_model_input_packed.to(
                    self.device_torch, cast_dtype),
                timestep=timestep / 1000,
                encoder_hidden_states=text_embeddings.text_embeds.to(
                    self.device_torch, cast_dtype),
                pooled_projections=text_embeddings.pooled_embeds.to(
                    self.device_torch, cast_dtype),
                txt_ids=txt_ids,
                img_ids=img_ids,
                guidance=guidance,
                return_dict=False,
                **kwargs,
            )[0]

            # remove kontext image conditioning
            noise_pred = noise_pred[:, :latent_size]

            if isinstance(noise_pred, QTensor):
                noise_pred = noise_pred.dequantize()

            noise_pred = rearrange(
                noise_pred,
                "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                h=latent_model_input.shape[2] // 2,
                w=latent_model_input.shape[3] // 2,
                ph=2,
                pw=2,
                c=self.vae.config.latent_channels
            )
        finally:
            if bypass_guidance_embedding:
                restore_flux_guidance(self.unet)

        return noise_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Encode ``prompt`` with both text encoders (max length 512)."""
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)
        prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
            self.tokenizer,
            self.text_encoder,
            prompt,
            max_length=512,
        )
        pe = PromptEmbeds(
            prompt_embeds
        )
        pe.pooled_embeds = pooled_prompt_embeds
        return pe

    def get_model_has_grad(self):
        """Report ``requires_grad`` of the transformer's ``proj_out`` weight."""
        # return from a weight if it has grad
        return self.model.proj_out.weight.requires_grad

    def get_te_has_grad(self):
        """Report ``requires_grad`` of one attention weight of the second text encoder."""
        # return from a weight if it has grad
        return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad

    def save_model(self, output_path, meta, save_dtype):
        """Save the transformer to ``<output_path>/transformer`` and ``meta`` to ``aitk_meta.yaml``.

        ``save_dtype`` is accepted but not used here.
        """
        # only save the unet
        transformer: FluxTransformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        """Return the detached target ``noise - batch.latents``."""
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        return (noise - batch.latents).detach()

    def condition_noisy_latents(self, latents: torch.Tensor, batch: 'DataLoaderBatchDTO'):
        """Append the VAE-encoded control image to ``latents`` on the channel axis.

        Returns ``latents`` detached and unchanged when the batch has no
        control tensor.
        """
        with torch.no_grad():
            control_tensor = batch.control_tensor
            if control_tensor is not None:
                self.vae.to(self.device_torch)
                # we are not packed here, so we just need to pass them so we can pack them later
                control_tensor = control_tensor * 2 - 1
                control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype)

                # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
                if batch.tensor is not None:
                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
                else:
                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
                    target_h = batch.file_items[0].crop_height
                    target_w = batch.file_items[0].crop_width

                if control_tensor.shape[2] != target_h or control_tensor.shape[3] != target_w:
                    control_tensor = F.interpolate(control_tensor, size=(target_h, target_w), mode='bilinear')

                control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype)
                latents = torch.cat((latents, control_latent), dim=1)

        return latents.detach()

    def get_base_model_version(self):
        """Return the base model version string."""
        return "flux.1_kontext"
class HidreamE1Model(HidreamModel):
    """HiDream E1 (image editing) variant of ``HidreamModel``.

    The control image is carried as extra latent channels by
    ``condition_noisy_latents`` and concatenated along the last latent axis
    inside ``get_noise_prediction``.
    """
    arch = "hidream_e1"
    hidream_transformer_class = HiDreamImageTransformer2DModel
    hidream_pipeline_class = HiDreamImageEditingPipeline

    def get_generation_pipeline(self):
        """Build an editing pipeline from the loaded components with a UniPC scheduler."""
        scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=1000, shift=3.0, use_dynamic_shifting=False
        )

        pipeline: HiDreamImageEditingPipeline = HiDreamImageEditingPipeline(
            scheduler=scheduler,
            vae=self.vae,
            text_encoder=self.text_encoder[0],
            tokenizer=self.tokenizer[0],
            text_encoder_2=self.text_encoder[1],
            tokenizer_2=self.tokenizer[1],
            text_encoder_3=self.text_encoder[2],
            tokenizer_3=self.tokenizer[2],
            text_encoder_4=self.text_encoder[3],
            tokenizer_4=self.tokenizer[3],
            transformer=unwrap_model(self.model),
            aggressive_unloading=self.low_vram,
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: HiDreamImageEditingPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Sample one edited image conditioned on ``gen_config.ctrl_img``.

        The control image is converted to RGB and resized to
        ``(gen_config.width, gen_config.height)`` when its size differs.

        Raises:
            ValueError: If ``gen_config.ctrl_img`` is None.
        """
        if gen_config.ctrl_img is None:
            # The message previously named Flux Kontext (copy-paste).
            raise ValueError(
                "Control image is required for HiDream E1 model generation."
            )

        # convert() returns a new image, so the source file can be closed here
        with Image.open(gen_config.ctrl_img) as source_img:
            control_img = source_img.convert("RGB")
        # resize to width and height
        if control_img.size != (gen_config.width, gen_config.height):
            control_img = control_img.resize(
                (gen_config.width, gen_config.height), Image.BILINEAR
            )

        img = pipeline(
            prompt_embeds_t5=conditional_embeds.text_embeds[0],
            prompt_embeds_llama3=conditional_embeds.text_embeds[1],
            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
            negative_prompt_embeds_t5=unconditional_embeds.text_embeds[0],
            negative_prompt_embeds_llama3=unconditional_embeds.text_embeds[1],
            negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            image=control_img,
            **extra,
        ).images[0]
        return img

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Encode ``prompt`` with all four encoders (max sequence length 128).

        The same prompt is passed for every encoder slot, without
        classifier-free guidance; the negative outputs are discarded.
        """
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 128
        (
            prompt_embeds_t5,
            negative_prompt_embeds_t5,
            prompt_embeds_llama3,
            negative_prompt_embeds_llama3,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            prompt_4=prompt,
            device=self.device_torch,
            dtype=self.torch_dtype,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
            do_classifier_free_guidance=False,
        )
        prompt_embeds = [prompt_embeds_t5, prompt_embeds_llama3]
        pe = PromptEmbeds([prompt_embeds, pooled_prompt_embeds])
        return pe

    def condition_noisy_latents(
        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
    ):
        """Append the VAE-encoded control image to ``latents`` on the channel axis.

        Returns ``latents`` detached and unchanged when the batch has no
        control tensor.
        """
        with torch.no_grad():
            control_tensor = batch.control_tensor
            if control_tensor is not None:
                self.vae.to(self.device_torch)
                # we are not packed here, so we just need to pass them so we can pack them later
                control_tensor = control_tensor * 2 - 1
                control_tensor = control_tensor.to(
                    self.vae_device_torch, dtype=self.torch_dtype
                )

                # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
                if batch.tensor is not None:
                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
                else:
                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
                    target_h = batch.file_items[0].crop_height
                    target_w = batch.file_items[0].crop_width

                if (
                    control_tensor.shape[2] != target_h
                    or control_tensor.shape[3] != target_w
                ):
                    control_tensor = F.interpolate(
                        control_tensor, size=(target_h, target_w), mode="bilinear"
                    )

                control_latent = self.encode_images(control_tensor).to(
                    latents.device, latents.dtype
                )
                latents = torch.cat((latents, control_latent), dim=1)

        return latents.detach()

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs,
    ):
        """Run the transformer and return the negated prediction.

        A 32-channel input is treated as [latent, control] stacked on the
        channel axis; the two halves are placed side by side on the last axis
        for the model, and only the first ``lat_size`` columns of the output
        are returned.
        """
        # NOTE(review): the extent of this no_grad block is reconstructed so
        # that the transformer call below stays differentiable - confirm
        # against the source file.
        with torch.no_grad():
            # make sure config is set
            self.model.config.force_inference_output = True
            has_control = False
            lat_size = latent_model_input.shape[-1]
            if latent_model_input.shape[1] == 32:
                # split the channel-stacked [latent, control] pair and lay
                # the two halves next to each other on the last axis
                lat, control = torch.chunk(latent_model_input, 2, dim=1)
                latent_model_input = torch.cat([lat, control], dim=-1)
                has_control = True

        dtype = self.model.dtype
        device = self.device_torch

        text_embeds = text_embeddings.text_embeds
        # run the to for the list
        text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]

        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            timesteps=timestep,
            encoder_hidden_states_t5=text_embeds[0],
            encoder_hidden_states_llama3=text_embeds[1],
            pooled_embeds=text_embeddings.pooled_embeds.to(device, dtype=dtype),
            return_dict=False,
        )[0]

        if has_control:
            # keep only the part that belongs to the noisy latent
            noise_pred = -1.0 * noise_pred[..., :lat_size]
        else:
            noise_pred = -1.0 * noise_pred

        return noise_pred
# Keyword arguments for the training scheduler built in get_train_scheduler().
scheduler_config = {
    "num_train_timesteps": 1000,
    "shift": 3.0
}

# LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Default llama checkpoint; can be overridden via model_kwargs['llama_model_path'].
LLAMA_MODEL_PATH = "unsloth/Meta-Llama-3.1-8B-Instruct"
BASE_MODEL_PATH = "HiDream-ai/HiDream-I1-Full"


class HidreamModel(BaseModel):
    """HiDream model wrapper: loading, sampling, prediction and saving.

    Subclasses can swap the transformer and pipeline classes through
    ``hidream_transformer_class`` and ``hidream_pipeline_class``.
    """
    arch = "hidream"
    hidream_transformer_class = HiDreamImageTransformer2DModel
    hidream_pipeline_class = HiDreamImagePipeline

    def __init__(
            self,
            device,
            model_config: ModelConfig,
            dtype='bf16',
            custom_pipeline=None,
            noise_scheduler=None,
            **kwargs
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['HiDreamImageTransformer2DModel']

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        """Return a new scheduler built from the module-level ``scheduler_config``."""
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        """Return 16, the multiple used for bucket sizes."""
        return 16

    def load_model(self):
        """Load the llama, CLIP and T5 encoders, transformer and VAE, then build the pipeline.

        Stores the results on ``self.vae``, ``self.text_encoder`` (list of four),
        ``self.tokenizer`` (list of four), ``self.model`` and ``self.pipeline``.
        With ``self.low_vram`` set, the large models are parked on the CPU while
        the others load and are moved back at the end.
        """
        dtype = self.torch_dtype
        # HiDream-ai/HiDream-I1-Full
        self.print_and_status_update("Loading HiDream model")
        # will be updated if we detect an existing checkpoint in training folder
        model_path = self.model_config.name_or_path
        extras_path = self.model_config.extras_name_or_path

        llama_model_path = self.model_config.model_kwargs.get('llama_model_path', LLAMA_MODEL_PATH)

        scheduler = HidreamModel.get_train_scheduler()

        self.print_and_status_update("Loading llama 8b model")

        tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
            llama_model_path,
            use_fast=False
        )

        text_encoder_4 = LlamaForCausalLM.from_pretrained(
            llama_model_path,
            output_hidden_states=True,
            output_attentions=True,
            torch_dtype=torch.bfloat16,
        )
        text_encoder_4.to(self.device_torch, dtype=dtype)

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing llama 8b model")
            quantization_type = get_qtype(self.model_config.qtype_te)
            quantize(text_encoder_4, weights=quantization_type)
            freeze(text_encoder_4)

        if self.low_vram:
            # unload it for now
            text_encoder_4.to('cpu')

        flush()

        self.print_and_status_update("Loading transformer")

        transformer = self.hidream_transformer_class.from_pretrained(
            model_path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16
        )

        if not self.low_vram:
            transformer.to(self.device_torch, dtype=dtype)

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing transformer")
            quantization_type = get_qtype(self.model_config.qtype)
            if self.low_vram:
                # move and quantize only certain pieces at a time.
                all_blocks = list(transformer.double_stream_blocks) + list(transformer.single_stream_blocks)
                self.print_and_status_update(" - quantizing transformer blocks")
                for block in tqdm(all_blocks):
                    block.to(self.device_torch, dtype=dtype)
                    quantize(block, weights=quantization_type)
                    freeze(block)
                    block.to('cpu')
                    # flush()

                self.print_and_status_update(" - quantizing extras")
                transformer.to(self.device_torch, dtype=dtype)
                quantize(transformer, weights=quantization_type)
                freeze(transformer)
            else:
                quantize(transformer, weights=quantization_type)
                freeze(transformer)

        if self.low_vram:
            # unload it for now
            transformer.to('cpu')

        flush()

        self.print_and_status_update("Loading vae")

        vae = AutoencoderKL.from_pretrained(
            extras_path,
            subfolder="vae",
            torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)


        self.print_and_status_update("Loading clip encoders")

        text_encoder = CLIPTextModelWithProjection.from_pretrained(
            extras_path,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)

        tokenizer = CLIPTokenizer.from_pretrained(
            extras_path,
            subfolder="tokenizer"
        )

        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            extras_path,
            subfolder="text_encoder_2",
            torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)

        tokenizer_2 = CLIPTokenizer.from_pretrained(
            extras_path,
            subfolder="tokenizer_2"
        )

        flush()
        self.print_and_status_update("Loading T5 encoders")

        text_encoder_3 = T5EncoderModel.from_pretrained(
            extras_path,
            subfolder="text_encoder_3",
            torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing T5")
            quantization_type = get_qtype(self.model_config.qtype_te)
            quantize(text_encoder_3, weights=quantization_type)
            freeze(text_encoder_3)
            flush()

        tokenizer_3 = T5Tokenizer.from_pretrained(
            extras_path,
            subfolder="tokenizer_3"
        )
        flush()

        if self.low_vram:
            self.print_and_status_update("Moving everything to device")
            # move it all back
            transformer.to(self.device_torch, dtype=dtype)
            vae.to(self.device_torch, dtype=dtype)
            text_encoder.to(self.device_torch, dtype=dtype)
            text_encoder_2.to(self.device_torch, dtype=dtype)
            text_encoder_4.to(self.device_torch, dtype=dtype)
            text_encoder_3.to(self.device_torch, dtype=dtype)

        # set to eval mode
        # transformer.eval()
        vae.eval()
        text_encoder.eval()
        text_encoder_2.eval()
        text_encoder_4.eval()
        text_encoder_3.eval()

        pipe = self.hidream_pipeline_class(
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            text_encoder_3=text_encoder_3,
            tokenizer_3=tokenizer_3,
            text_encoder_4=text_encoder_4,
            tokenizer_4=tokenizer_4,
            transformer=transformer,
        )

        flush()

        text_encoder_list = [text_encoder, text_encoder_2, text_encoder_3, text_encoder_4]
        tokenizer_list = [tokenizer, tokenizer_2, tokenizer_3, tokenizer_4]

        for te in text_encoder_list:
            # set the dtype
            te.to(self.device_torch, dtype=dtype)
            # freeze the model
            # NOTE(review): freeze() here is the optimum.quanto import and is
            # also applied to encoders that were never quantized - confirm
            # this is intended.
            freeze(te)
            # set to eval mode
            te.eval()
            # set the requires grad to false
            te.requires_grad_(False)

        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder_list  # list of text encoders
        self.tokenizer = tokenizer_list  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        """Build a sampling pipeline from the loaded components with a UniPC scheduler."""
        scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=1000,
            shift=3.0,
            use_dynamic_shifting=False
        )

        # NOTE(review): this uses HiDreamImagePipeline directly, while
        # load_model goes through self.hidream_pipeline_class.
        pipeline: HiDreamImagePipeline = HiDreamImagePipeline(
            scheduler=scheduler,
            vae=self.vae,
            text_encoder=self.text_encoder[0],
            tokenizer=self.tokenizer[0],
            text_encoder_2=self.text_encoder[1],
            tokenizer_2=self.tokenizer[1],
            text_encoder_3=self.text_encoder[2],
            tokenizer_3=self.tokenizer[2],
            text_encoder_4=self.text_encoder[3],
            tokenizer_4=self.tokenizer[3],
            transformer=unwrap_model(self.model),
            aggressive_unloading=self.low_vram
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: HiDreamImagePipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Sample one image from precomputed positive and negative prompt embeds."""
        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        """Run the transformer and return the negated prediction.

        Non-square latents are patchified and zero-padded to
        ``self.transformer.max_seq`` tokens, with matching ``img_sizes`` and
        ``img_ids``; square latents are passed through as they are, with both
        set to None. ``**kwargs`` is accepted but not forwarded to the model.
        """
        batch_size = latent_model_input.shape[0]
        # NOTE(review): the extent of this no_grad block is reconstructed so
        # that the transformer call below stays differentiable - confirm
        # against the source file.
        with torch.no_grad():
            if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
                B, C, H, W = latent_model_input.shape
                pH, pW = H // self.model.config.patch_size, W // self.model.config.patch_size

                img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
                img_ids = torch.zeros(pH, pW, 3)
                img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
                img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
                img_ids = img_ids.reshape(pH * pW, -1)
                # pad the ids with zeros up to the maximum sequence length
                img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
                img_ids_pad[:pH*pW, :] = img_ids

                img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device)
                img_sizes = torch.cat([img_sizes] * batch_size, dim=0)
                img_ids = img_ids_pad.unsqueeze(0).to(latent_model_input.device)
                img_ids = torch.cat([img_ids] * batch_size, dim=0)
            else:
                img_sizes = img_ids = None

        dtype = self.model.dtype
        device = self.device_torch

        # Pack the latent
        if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
            B, C, H, W = latent_model_input.shape
            patch_size = self.transformer.config.patch_size
            pH, pW = H // patch_size, W // patch_size
            out = torch.zeros(
                (B, C, self.transformer.max_seq, patch_size * patch_size),
                dtype=latent_model_input.dtype,
                device=latent_model_input.device
            )
            latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
            # real patches fill the front; the remainder stays zero padding
            out[:, :, 0:pH*pW] = latent_model_input
            latent_model_input = out

        text_embeds = text_embeddings.text_embeds
        # run the to for the list
        text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]

        noise_pred = self.transformer(
            hidden_states = latent_model_input,
            timesteps = timestep,
            encoder_hidden_states = text_embeds,
            pooled_embeds = text_embeddings.pooled_embeds.to(device, dtype=dtype),
            img_sizes = img_sizes,
            img_ids = img_ids,
            return_dict = False,
        )[0]
        # the transformer output is negated before it is returned
        noise_pred = -noise_pred

        return noise_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Encode ``prompt`` with all four encoders (max sequence length 128).

        The same prompt is passed for every encoder slot.
        """
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 128
        prompt_embeds, pooled_prompt_embeds = self.pipeline._encode_prompt(
            prompt = prompt,
            prompt_2 = prompt,
            prompt_3 = prompt,
            prompt_4 = prompt,
            device = self.device_torch,
            dtype = self.torch_dtype,
            num_images_per_prompt = 1,
            max_sequence_length = max_sequence_length,
        )
        pe = PromptEmbeds(
            [prompt_embeds, pooled_prompt_embeds]
        )
        return pe

    def get_model_has_grad(self):
        """Report ``requires_grad`` of one attention weight in the first double-stream block."""
        # return from a weight if it has grad
        return self.model.double_stream_blocks[0].block.attn1.to_q.weight.requires_grad

    def get_te_has_grad(self):
        """Always False: the text encoders are treated as frozen."""
        # assume no one wants to finetune 4 text encoders.
        return False

    def save_model(self, output_path, meta, save_dtype):
        """Save the transformer to ``<output_path>/transformer`` and ``meta`` to ``aitk_meta.yaml``.

        ``save_dtype`` is accepted but not used here.
        """
        # only save the unet
        transformer: HiDreamImageTransformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        """Return the detached target ``noise - batch.latents``."""
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        return (noise - batch.latents).detach()

    def get_transformer_block_names(self) -> Optional[List[str]]:
        """Return the attribute names of the transformer's block lists."""
        return ['double_stream_blocks', 'single_stream_blocks']

    def convert_lora_weights_before_save(self, state_dict):
        """Return a copy of ``state_dict`` with ``transformer.`` renamed to ``diffusion_model.``."""
        # keys currently start with "transformer." but need to start with
        # "diffusion_model." for comfyui
        # NOTE(review): str.replace rewrites every occurrence in the key, not
        # only a leading prefix - confirm no key contains it elsewhere.
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        """Return a copy of ``state_dict`` with ``diffusion_model.`` renamed to ``transformer.``."""
        # saved as "diffusion_model." but needs to be "transformer." for ai-toolkit
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def get_base_model_version(self):
        """Return the base model version string."""
        return "hidream_i1"
class FakeConfig:
    """Empty attribute bag used as a stand-in ``config`` object."""

    pass


class FakeTextEncoder(torch.nn.Module):
    """Parameter-less placeholder that only remembers a dtype and a device.

    The module owns no parameters or buffers, so ``torch.nn.Module.to`` has
    nothing to move and cannot report where the module "lives". The target
    dtype/device are therefore tracked by hand and exposed through the
    ``dtype`` and ``device`` properties.

    Args:
        scaling_factor: value stored on ``self.config.scaling_factor``.
    """

    def __init__(self, scaling_factor=1.0):
        super().__init__()
        self._dtype = torch.float32
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = FakeConfig()
        self.config.scaling_factor = scaling_factor

    @property
    def dtype(self):
        """Last dtype recorded via ``to()`` or the setter (float32 initially)."""
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        """Last device recorded via ``to()`` or the setter."""
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    def to(self, *args, **kwargs):
        """Mimic ``torch.nn.Module.to`` and record the requested dtype/device.

        Handles the positional call forms ``to(device, dtype)``, ``to(dtype)``
        and ``to(tensor)`` as well as the ``dtype=`` / ``device=`` keywords.
        Previously only the keywords were recorded, so a call such as
        ``.to(some_device, dtype=...)`` left ``device`` at its default.
        A keyword explicitly set to ``None`` means "unchanged" and no longer
        overwrites the tracked value.

        Returns:
            ``self`` (the result of ``torch.nn.Module.to``).
        """
        # Delegate first: torch validates the argument combination for us and
        # raises before any tracked state is touched.
        moved = super().to(*args, **kwargs)

        for position, arg in enumerate(args):
            if isinstance(arg, bool):
                # Positional bool is the ``non_blocking`` flag, not a device index.
                continue
            if isinstance(arg, torch.dtype):
                self._dtype = arg
            elif isinstance(arg, torch.Tensor):
                # ``to(tensor)`` adopts both the tensor's dtype and device.
                self._dtype = arg.dtype
                self._device = arg.device
            elif isinstance(arg, (str, torch.device)):
                # Stored as passed, matching how the keyword form is stored.
                self._device = arg
            elif isinstance(arg, int) and position == 0:
                # A bare integer in first position is a device index.
                self._device = torch.device(arg)

        if kwargs.get("dtype") is not None:
            self._dtype = kwargs["dtype"]
        if kwargs.get("device") is not None:
            self._device = kwargs["device"]
        return moved
Error: {e}" + ) + processor_path = self.model_config.extras_name_or_path + if processor_path.endswith(".safetensors"): + processor_path = "HiDream-ai/HiDream-O1-Image" + processor = AutoProcessor.from_pretrained(processor_path) + + tokenizer = get_tokenizer(processor) + add_special_tokens(tokenizer) + + if model_path.endswith(".safetensors"): + self.is_comfy_weight = True + self.print_and_status_update( + "Model is in safetensors format, loading with safetensors" + ) + state_dict = load_file(model_path) + + for key, value in state_dict.items(): + state_dict[key] = value.to(dtype=dtype) + + # comfy ui is missing the lm head. It isnt used, but our model needs it for now + state_dict["lm_head.weight"] = torch.zeros( + 151936, 4096, dtype=torch.bfloat16, device="cpu" + ) + + # transformer.load_state_dict(state_dict, assign=True) + transformer = Qwen3VLForConditionalGeneration.from_pretrained( + None, + config=Qwen3VLConfig(**model_config), + state_dict=state_dict, + torch_dtype=self.torch_dtype, + ) + del state_dict # free memory + else: + transformer = Qwen3VLForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=self.torch_dtype, + ) + flush() + if not self.model_config.low_vram: + transformer.to(self.device_torch) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=[], + ) + + flush() + + # move over to device now if low vram + if self.model_config.low_vram: + transformer.to(self.device_torch) + + # fake ones so the trainer doesnt break + vae = FakeVAE().to(self.device_torch, dtype=dtype) + text_encoder = FakeTextEncoder().to(self.device_torch, dtype=dtype) + + self.noise_scheduler = 
def decode_latents(self, latents: torch.Tensor, device=None, dtype=None):
    """Cast ``latents`` to the requested device/dtype and return them.

    No actual decoding happens: per the original note there is no latent
    space for this model, so the tensor is only moved/cast.

    Args:
        latents: tensor to cast.
        device: target device; falls back to ``self.vae_device_torch``.
        dtype: target dtype; falls back to ``self.vae_torch_dtype``.

    Returns:
        ``latents.to(device, dtype=dtype)``.
    """
    # Bring the stand-in VAE over to the training device if it reports CPU.
    if self.vae.device == torch.device("cpu"):
        self.vae.to(self.device_torch)

    target_device = self.vae_device_torch if device is None else device
    target_dtype = self.vae_torch_dtype if dtype is None else dtype
    return latents.to(target_device, dtype=target_dtype)
pipeline( + # prompt=gen_config.prompt, + prompt_input_ids=conditional_embeds.text_embeds[0], + # negative_prompt=gen_config.negative_prompt, + negative_prompt_input_ids=unconditional_embeds.text_embeds[0], + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + generator=generator, + noise_scale=self.noise_scale_inference, + **extra, + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: AdvancedPromptEmbeds, + batch: "DataLoaderBatchDTO", + **kwargs, + ): + import einops + from .src.hidream_o1.pipeline import PATCH_SIZE, T_EPS + + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + device = self.device_torch + in_dtype = latent_model_input.dtype + bs, _, h_pix, w_pix = latent_model_input.shape + h_patches = h_pix // PATCH_SIZE + w_patches = w_pix // PATCH_SIZE + + # (B, C, H, W) -> (B, H/p * W/p, C * p * p) + z = einops.rearrange( + latent_model_input, + "B C (H p1) (W p2) -> B (H W) (C p1 p2)", + p1=PATCH_SIZE, + p2=PATCH_SIZE, + ).to(device) + + model_config = self.model.config + pad_token_id = getattr(model_config, "pad_token_id", 0) or 0 + + with torch.no_grad(): + # Build per-sample conditioning, then left-pad the text portion so + # the boi/tms + vision-token suffix stays at the end of the + # sequence (the t2i layout assumes vision tokens are at the tail). 
+ per_sample = [] + for b in range(bs): + tokens = text_embeddings.text_embeds[b] + if tokens.dim() == 1: + tokens = tokens.unsqueeze(0) + per_sample.append( + self.pipeline.build_conditioning_sample( + tokens.to(device), + h_pix, + w_pix, + ) + ) + + max_seq_len = max(s["input_ids"].shape[-1] for s in per_sample) + ids_l, pos_l, tt_l, vm_l, mask_l = [], [], [], [], [] + for s in per_sample: + ids = s["input_ids"].to(device) + pos = s["position_ids"].to(device) + tt = s["token_types"].to(device) + vm = s["vinput_mask"].to(device) + seq_len = ids.shape[-1] + pad_len = max_seq_len - seq_len + + if pad_len > 0: + ids = torch.cat( + [ + torch.full( + (1, pad_len), + pad_token_id, + dtype=ids.dtype, + device=device, + ), + ids, + ], + dim=-1, + ) + pos = torch.cat( + [ + torch.ones((3, 1, pad_len), dtype=pos.dtype, device=device), + pos, + ], + dim=-1, + ) + tt = torch.cat( + [ + torch.zeros((1, pad_len), dtype=tt.dtype, device=device), + tt, + ], + dim=-1, + ) + vm = torch.cat( + [ + torch.zeros((1, pad_len), dtype=vm.dtype, device=device), + vm, + ], + dim=-1, + ) + mask = torch.cat( + [ + torch.zeros((1, pad_len), dtype=torch.long, device=device), + torch.ones((1, seq_len), dtype=torch.long, device=device), + ], + dim=-1, + ) + else: + mask = torch.ones((1, seq_len), dtype=torch.long, device=device) + + ids_l.append(ids) + pos_l.append(pos) + tt_l.append(tt) + vm_l.append(vm) + mask_l.append(mask) + + input_ids = torch.cat(ids_l, dim=0) + position_ids = torch.cat(pos_l, dim=1) # (3, B, S) + token_types = torch.cat(tt_l, dim=0) + vinput_mask = torch.cat(vm_l, dim=0) + attention_mask = torch.cat(mask_l, dim=0) + + # Model wants timestep as denoising progress in (0, 1) where 1=clean. 
+ t_pixeldit = (1.0 - timestep.float() / 1000.0).to(device) + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask if bs > 1 else None, + vinputs=z, + timestep=t_pixeldit.reshape(-1), + token_types=token_types, + use_flash_attn=False, + ) + x_pred = outputs.x_pred # (B, S, C*p*p) over the full padded sequence + + # Pull the vision-token positions only. + vision_pred = torch.stack( + [x_pred[b][vinput_mask[b].bool()] for b in range(bs)], + dim=0, + ) # (B, image_len, C*p*p) + + x0_pred = einops.rearrange( + vision_pred, + "B (H W) (C p1 p2) -> B C (H p1) (W p2)", + H=h_patches, + W=w_patches, + p1=PATCH_SIZE, + p2=PATCH_SIZE, + ) + + # Model emits an x0-prediction; convert to flow-matching velocity + # (x_1 - x_0) so it matches the loss target from get_loss_target. + sigma = (timestep.float() / 1000.0).clamp_min(T_EPS).to(device) + while sigma.dim() < latent_model_input.dim(): + sigma = sigma.unsqueeze(-1) + pred = (latent_model_input.float().to(device) - x0_pred.float()) / sigma + return pred.to(in_dtype) + + def get_prompt_embeds(self, prompt: list) -> AdvancedPromptEmbeds: + if not isinstance(prompt, list): + prompt = [prompt] + # empty, we cannot use them with this omni model anyway, but will break trainer if they do not exist + token_list = [self.pipeline.encode_prompt(p) for p in prompt] + pe = AdvancedPromptEmbeds(text_embeds=token_list) + pe._frozen_dtype_keys = ["text_embeds"] + return pe + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + from toolkit.util.quantize import dequantize_if_quantized + transformer: Qwen3VLForConditionalGeneration = unwrap_model(self.model) + if self.is_comfy_weight: + sd = transformer.state_dict() + save_dict = {} + for key, value in sd.items(): + if "lm_head.weight" in key: + continue # comfy checkpoint doesnt have the lm head, so skip it + # dequantize any quantized (e.g. 
def convert_lora_weights_before_load(self, state_dict):
    """Return a copy of ``state_dict`` with saved LoRA keys renamed for loading.

    Two textual replacements are applied to every key, in this order:

    1. ``"diffusion_model."`` becomes ``"transformer.model."``.
    2. ``"transformer.model.model."`` collapses to ``"transformer.model."``;
       the original comment describes this as support for legacy keys.

    Values are passed through unchanged and the input mapping is not modified.
    """
    return {
        name.replace("diffusion_model.", "transformer.model.").replace(
            "transformer.model.model.", "transformer.model."
        ): tensor
        for name, tensor in state_dict.items()
    }
HiDreamImageTransformer2DModel +from .pipelines.hidream_image.pipeline_hidream_image import HiDreamImagePipeline diff --git a/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/model_config.py b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ded2ff2cbe55a1f70a1dfda4250b74b265dfd3d2 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/model_config.py @@ -0,0 +1,52 @@ +model_config = { + "architectures": ["Qwen3VLForConditionalGeneration"], + "image_token_id": 151655, + "model_type": "qwen3_vl", + "text_config": { + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 262144, + "model_type": "qwen3_vl_text", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_interleaved": True, + "mrope_section": [24, 20, 20], + "rope_type": "default", + }, + "rope_theta": 5000000, + "use_cache": True, + "vocab_size": 151936, + }, + "tie_word_embeddings": False, + "transformers_version": "4.57.0.dev0", + "video_token_id": 151656, + "vision_config": { + "deepstack_visual_indexes": [8, 16, 24], + "depth": 27, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "in_channels": 3, + "initializer_range": 0.02, + "intermediate_size": 4304, + "model_type": "qwen3_vl", + "num_heads": 16, + "num_position_embeddings": 2304, + 
"out_hidden_size": 4096, + "patch_size": 16, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, +} diff --git a/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/pipeline.py b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..8e24da332252b8b7d425817037fefbf7769198b5 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/pipeline.py @@ -0,0 +1,455 @@ +from typing import List, Optional, Union + +import einops +import numpy as np +import torch +from PIL import Image +import torchvision.transforms.v2 as transforms + +from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput +from dataclasses import dataclass + + +TIMESTEP_TOKEN_NUM = 1 +DEFAULT_NOISE_SCALE = 8.0 +T_EPS = 0.001 +PATCH_SIZE = 32 + +TENSOR_TRANSFORM = transforms.Compose( + [ + transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), + transforms.Normalize([0.5], [0.5]), + ] +) + + +def round_to_patch(dim: int, patch: int = PATCH_SIZE) -> int: + return max(patch, int(dim // patch * patch)) + + +def _get_rope_index_t2i( + spatial_merge_size: int, + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + input_ids: torch.LongTensor, + image_grid_thw: torch.LongTensor, + skip_vision_start_token: List[int], + fix_point: int = 4096, +): + """Compute mrope position ids for the t2i case used by HiDream-O1.""" + attention_mask = torch.ones_like(input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + + for i, ids_row in enumerate(input_ids): + ids_row = ids_row[attention_mask[i] == 1] + vision_start_indices = torch.argwhere(ids_row == vision_start_token_id).squeeze( + 1 + ) + vision_tokens = 
ids_row[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum().item() + video_nums = (vision_tokens == video_token_id).sum().item() + input_tokens = ids_row.tolist() + + llm_pos_ids_list = [] + st = 0 + image_index = 0 + video_index = 0 + remain_images, remain_videos = image_nums, video_nums + local_fix_point = fix_point + + for _ in range(image_nums + video_nums): + ed_image = ( + input_tokens.index(image_token_id, st) + if (image_token_id in input_tokens and remain_images > 0) + else len(input_tokens) + 1 + ) + ed_video = ( + input_tokens.index(video_token_id, st) + if (video_token_id in input_tokens and remain_videos > 0) + else len(input_tokens) + 1 + ) + if ed_image < ed_video: + t, h, w = image_grid_thw[image_index].tolist() + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = image_grid_thw[video_index].tolist() + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t = t + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + + text_len = ed - st - skip_vision_start_token[image_index - 1] + text_len = max(0, text_len) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + + if skip_vision_start_token[image_index - 1]: + if local_fix_point > 0: + local_fix_point = local_fix_point - st_idx + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + local_fix_point + st_idx + ) + local_fix_point = 0 + else: + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + 
def _build_t2i_sample_from_input_ids(
    input_ids: torch.Tensor,
    height: int,
    width: int,
    model_config,
    attention_mask: Optional[torch.Tensor] = None,
):
    """Wrap an already-tokenized prompt into a text-to-image conditioning sample.

    Args:
        input_ids: prompt token ids; a 1-D tensor is promoted to ``(1, seq)``.
        height: image height in pixels (integer-divided by ``PATCH_SIZE``).
        width: image width in pixels (integer-divided by ``PATCH_SIZE``).
        model_config: object providing ``image_token_id``, ``video_token_id``
            and ``vision_start_token_id``.
        attention_mask: optional mask; a 1-D tensor is promoted to 2-D and it
            is stored in the result only when given.

    Returns:
        dict with
        ``input_ids`` (the 2-D text ids, without the appended vision tokens),
        ``position_ids`` (from ``_get_rope_index_t2i`` over text + vision tokens),
        ``token_types`` (1 wherever the timestep/image span is marked, else 0),
        ``vinput_mask`` (True only on the image positions), and optionally
        ``attention_mask``.
    """
    grid_h = height // PATCH_SIZE
    grid_w = width // PATCH_SIZE
    num_image_tokens = grid_h * grid_w

    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)

    grid_thw = torch.tensor([1, grid_h, grid_w], dtype=torch.int64).unsqueeze(0)

    # Placeholder vision tokens: all image tokens, except the first slot which
    # carries the vision-start token.
    image_slots = torch.full(
        (1, num_image_tokens),
        model_config.image_token_id,
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    image_slots[0, 0] = model_config.vision_start_token_id
    ids_with_vision = torch.cat([input_ids, image_slots], dim=-1)

    position_ids = _get_rope_index_t2i(
        spatial_merge_size=1,
        image_token_id=model_config.image_token_id,
        video_token_id=model_config.video_token_id,
        vision_start_token_id=model_config.vision_start_token_id,
        input_ids=ids_with_vision,
        image_grid_thw=grid_thw,
        skip_vision_start_token=[1],
    )

    text_len = input_ids.shape[-1]
    total_len = position_ids.shape[-1]

    # Marker layout: the trailing timestep token(s) of the text get 3, the
    # image span that follows gets 1, everything else stays 0.
    tms_start = text_len - TIMESTEP_TOKEN_NUM
    markers = torch.zeros((1, total_len), dtype=input_ids.dtype)
    markers[0, tms_start : tms_start + num_image_tokens + TIMESTEP_TOKEN_NUM] = 1
    markers[0, tms_start:text_len] = 3

    sample = {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "token_types": (markers > 0).to(markers.dtype),
        "vinput_mask": markers == 1,
    }

    if attention_mask is None:
        return sample
    if attention_mask.dim() == 1:
        attention_mask = attention_mask.unsqueeze(0)
    sample["attention_mask"] = attention_mask
    return sample
def encode_prompt(self, prompt: str) -> torch.Tensor:
    """Render ``prompt`` through the chat template and tokenize it.

    The prompt is wrapped as a single user message, rendered with the
    processor's chat template (generation prompt included), and then suffixed
    with the boi token followed by ``TIMESTEP_TOKEN_NUM`` tms tokens. The
    token strings are read from the tokenizer when present, otherwise the
    literal defaults are used.

    Returns:
        The result of ``tokenizer.encode(..., return_tensors="pt",
        add_special_tokens=False)``; these ids can be passed back into
        ``__call__`` as ``prompt_input_ids`` / ``negative_prompt_input_ids``.
    """
    tok = self.tokenizer
    boi = getattr(tok, "boi_token", "<|boi_token|>")
    tms = getattr(tok, "tms_token", "<|tms_token|>")

    chat_text = self.processor.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    full_text = chat_text + boi + tms * TIMESTEP_TOKEN_NUM

    return tok.encode(full_text, return_tensors="pt", add_special_tokens=False)
+ ) + return x[0] + return x + + prompt = _unwrap_str(prompt) + negative_prompt = _unwrap_str(negative_prompt) + + device = self._execution_device + dtype = torch.bfloat16 + model_config = self.model.config + + width, height = self._snap_resolution(width, height) + h_patches = height // PATCH_SIZE + w_patches = width // PATCH_SIZE + + do_cfg = guidance_scale > 1.0 + + if prompt_input_ids is None: + prompt_input_ids = self.encode_prompt(prompt) + if do_cfg and negative_prompt_input_ids is None: + if negative_prompt is None: + negative_prompt = " " + negative_prompt_input_ids = self.encode_prompt(negative_prompt) + + cond_sample = _build_t2i_sample_from_input_ids( + prompt_input_ids, + height, + width, + model_config, + attention_mask=prompt_attention_mask, + ) + uncond_sample = ( + _build_t2i_sample_from_input_ids( + negative_prompt_input_ids, + height, + width, + model_config, + attention_mask=negative_prompt_attention_mask, + ) + if do_cfg + else None + ) + + def _to_device(s): + return { + k: (v.to(device) if torch.is_tensor(v) else v) for k, v in s.items() + } + + cond_sample = _to_device(cond_sample) + if uncond_sample is not None: + uncond_sample = _to_device(uncond_sample) + + if generator is None: + if seed is None: + seed = 0 + generator = torch.Generator(device="cpu").manual_seed(seed + 1) + torch.manual_seed(seed + 1) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed + 1) + + noise = noise_scale * torch.randn( + (1, 3, height, width), generator=generator + ).to(device, dtype) + z = einops.rearrange( + noise, + "B C (H p1) (W p2) -> B (H W) (C p1 p2)", + p1=PATCH_SIZE, + p2=PATCH_SIZE, + ) + + if shift is not None and hasattr(self.scheduler, "set_shift"): + self.scheduler.set_shift(shift) + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + def _forward_once(sample, z_in, t_pixeldit): + with torch.autocast(device.type, dtype=dtype): + kwargs = { + "input_ids": sample["input_ids"], + 
"position_ids": sample["position_ids"], + "vinputs": z_in, + "timestep": t_pixeldit.reshape(-1).to(device), + "token_types": sample["token_types"], + "use_flash_attn": True, + } + if "attention_mask" in sample: + kwargs["attention_mask"] = sample["attention_mask"] + outputs = self.model(**kwargs) + x_pred = outputs.x_pred + return x_pred[0, sample["vinput_mask"][0]].unsqueeze(0) + + for step_t in self.progress_bar(timesteps): + t_pixeldit = 1.0 - step_t.float() / 1000.0 + sigma = (step_t.float() / 1000.0).to(dtype=torch.float32).clamp_min(T_EPS) + + x_pred_cond = _forward_once(cond_sample, z.clone(), t_pixeldit) + v_cond = (x_pred_cond.float() - z.float()) / sigma + + if do_cfg: + x_pred_uncond = _forward_once(uncond_sample, z.clone(), t_pixeldit) + v_uncond = (x_pred_uncond.float() - z.float()) / sigma + v_guided = v_uncond + guidance_scale * (v_cond - v_uncond) + else: + v_guided = v_cond + + model_output = -v_guided + z = self.scheduler.step( + model_output.float(), + step_t.to(dtype=torch.float32), + z.float(), + return_dict=False, + )[0].to(dtype) + + img = (z + 1) / 2 + img = einops.rearrange( + img.cpu().float(), + "B (H W) (C p1 p2) -> B C (H p1) (W p2)", + H=h_patches, + W=w_patches, + p1=PATCH_SIZE, + p2=PATCH_SIZE, + ) + + if output_type == "pt": + images = img.clamp(0, 1) + elif output_type == "np": + images = np.clip(img.numpy().transpose(0, 2, 3, 1), 0, 1) + else: + arr = np.round( + np.clip(img[0].numpy().transpose(1, 2, 0) * 255, 0, 255) + ).astype(np.uint8) + images = [Image.fromarray(arr).convert("RGB")] + + if not return_dict: + return (images,) + return HiDreamO1PipelineOutput(images=images) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/qwen3_vl_transformers.py b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/qwen3_vl_transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..87d8c57c3f4966628e44f7c29c24009a81f5c647 --- /dev/null +++ 
b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/hidream_o1/qwen3_vl_transformers.py @@ -0,0 +1,2446 @@ +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +USE_BF16_ROPE = os.environ.get("USE_BF16_ROPE", "0") +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import ( + TransformersKwargs, + auto_docstring, + is_torchdynamo_compiling, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import merge_with_config_defaults +from 
transformers.models.qwen3_vl.configuration_qwen3_vl import ( + Qwen3VLConfig, + Qwen3VLTextConfig, + Qwen3VLVisionConfig, +) + + +class Qwen3VLVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3VLVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d( + self.in_channels, + self.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim + ) + return hidden_states + + +class Qwen3VLVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + freqs = torch.outer(seq, self.inv_freq) 
+ return freqs + + +class Qwen3VLVisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm( + self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6 + ) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm( + x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x + ).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + is_causal: bool = False, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + elif is_causal: + q_len = query.shape[-2] + k_len = key_states.shape[-2] + causal_mask = torch.ones( + q_len, k_len, dtype=torch.bool, device=query.device + ).triu(diagonal=k_len - q_len + 1) + attn_weights = attn_weights.masked_fill( + causal_mask, torch.finfo(attn_weights.dtype).min + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3VLVisionAttention(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + 
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states) + .reshape(seq_length, 3, self.num_heads, -1) + .permute(1, 0, 2, 3) + .unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision( + query_states, key_states, cos, sin + ) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) + for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + 
attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen3VLVisionBlock(GradientCheckpointingLayer): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3VLVisionAttention(config=config) + self.mlp = Qwen3VLVisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3VLTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3VLTextConfig, device=None): + super().__init__() + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", "default") + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + # Newer transformers removed "default" from ROPE_INIT_FUNCTIONS in + # favor of a static method on the rotary module. Match that. 
+ if self.rope_type == "default": + self.rope_init_fn = self.compute_default_rope_parameters + else: + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) + + @staticmethod + def compute_default_rope_parameters( + config=None, + device=None, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, float]: + """Standard RoPE inv_freq. Mirrors upstream Qwen3VL since newer + transformers removed "default" from ROPE_INIT_FUNCTIONS.""" + rope_params = getattr(config, "rope_parameters", None) + base = ( + rope_params["rope_theta"] + if rope_params and "rope_theta" in rope_params + else getattr(config, "rope_theta", 10000.0) + ) + dim = ( + getattr(config, "head_dim", None) + or config.hidden_size // config.num_attention_heads + ) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, dim, 2, dtype=torch.int64).to( + device=device, dtype=torch.float + ) + / dim + ) + ) + return inv_freq, 1.0 + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + if USE_BF16_ROPE == "1": + inv_freq_expanded = ( + self.inv_freq[None, None, :, None] + .float() + .to(device=x.device) + .expand(3, position_ids.shape[1], -1, 1) + ) + else: + inv_freq_expanded = ( + self.original_inv_freq[None, None, :, None] + .float() + .to(device=x.device) + .expand(3, position_ids.shape[1], -1, 1) + ) + # inv_freq_expanded = self.inv_freq[None, None, :, None].float().to(device=x.device).expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[ + :, :, None, : + ].float() # shape (3, bs, 1, positions) + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3VLTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Qwen3VLTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return 
self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3VLTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Qwen3VLTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! 
+ self.k_norm = Qwen3VLTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm( + self.q_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + key_states = self.k_norm( + self.k_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3VLTextMLP(nn.Module): + def __init__(self, config): + 
super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3VLTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3VLTextMLP(config) + self.input_layernorm = Qwen3VLTextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Qwen3VLTextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = 
hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class Qwen3VLModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + x_pred: Optional[torch.FloatTensor] = None + mid_results: Optional[list] = None + + +@auto_docstring +class Qwen3VLPreTrainedModel(PreTrainedModel): + config: Qwen3VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen3VLTextDecoderLayer, + "attentions": Qwen3VLTextAttention, + } + + +class Qwen3VLVisionModel(Qwen3VLPreTrainedModel): + config: Qwen3VLVisionConfig + _no_split_modules = ["Qwen3VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> 
None:
        # Builds the vision tower: patch embedding, learned position table,
        # rotary frequencies, transformer blocks, the final merger and one
        # extra merger per deepstack tap layer.
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen3VLVisionPatchEmbed(
            config=config,
        )

        self.pos_embed = nn.Embedding(
            config.num_position_embeddings, config.hidden_size
        )
        # The position table is treated as a square grid with this side length.
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        head_dim = config.hidden_size // config.num_heads
        # Half of head_dim per axis: rot_pos_emb looks up (row, col) pairs and
        # flattens them together.
        self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [Qwen3VLVisionBlock(config) for _ in range(config.depth)]
        )
        self.merger = Qwen3VLVisionPatchMerger(
            config=config,
            use_postshuffle_norm=False,
        )

        self.deepstack_visual_indexes = config.deepstack_visual_indexes
        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3VLVisionPatchMerger(
                    config=config,
                    use_postshuffle_norm=True,
                )
                for _ in range(len(config.deepstack_visual_indexes))
            ]
        )

        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build per-token rotary frequencies from (row, col) patch coordinates.

        For every `(t, h, w)` row of `grid_thw`, tokens are enumerated block by
        block (`merge_size x merge_size` blocks, row-major inside each block),
        repeated `t` times, and their row/col indices are used to look up the
        frequency table. The two lookups are flattened into one vector per token.
        """
        merge_size = self.spatial_merge_size

        max_hw = int(grid_thw[:, 1:].max().item())
        freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
        device = freq_table.device

        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)

        offset = 0
        for num_frames, height, width in grid_thw:
            merged_h, merged_w = height // merge_size, width // merge_size

            block_rows = torch.arange(merged_h, device=device)  # block row indices
            block_cols = torch.arange(merged_w, device=device)  # block col indices
            intra_row = torch.arange(
                merge_size, device=device
            )  # intra-block row offsets
            intra_col = torch.arange(
                merge_size, device=device
            )  # intra-block col offsets

            # Compute full-resolution positions
            row_idx = (
                block_rows[:, None, None, None] * merge_size
                + intra_row[None, None, :, None]
            )
            col_idx = (
                block_cols[None, :, None, None] * merge_size
                + intra_col[None, None, None, :]
            )

            # Broadcast to (merged_h, merged_w, merge_size, merge_size) and flatten,
            # which yields the block-major token order.
            row_idx = row_idx.expand(
                merged_h, merged_w, merge_size, merge_size
            ).reshape(-1)
            col_idx = col_idx.expand(
                merged_h, merged_w, merge_size, merge_size
            ).reshape(-1)

            coords = torch.stack((row_idx, col_idx), dim=-1)

            # Every frame reuses the same spatial coordinates.
            if num_frames > 1:
                coords = coords.repeat(num_frames, 1)

            num_tokens = coords.shape[0]
            pos_ids[offset : offset + num_tokens] = coords
            offset += num_tokens

        embeddings = freq_table[pos_ids]  # lookup rotary embeddings
        embeddings = embeddings.flatten(1)
        return embeddings

    def fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly interpolate the learned position table to each image grid.

        For each `(t, h, w)` the square `num_grid_per_side` table is sampled at
        `h x w` evenly spaced points; the four neighbouring entries are gathered
        and blended with bilinear weights. The result is repeated over `t` and
        reordered into the same block-major layout used by the merger.
        """
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

        # One list per bilinear corner: (floor,floor), (floor,ceil), (ceil,floor), (ceil,ceil).
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            # Ceil neighbours are clipped so they stay inside the table.
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

            # Fractional distance from the floor neighbour along each axis.
            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            # Row offsets into the flattened (side * side) table.
            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side

            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]

            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        idx_tensor = torch.tensor(
            idx_list, dtype=torch.long, device=self.pos_embed.weight.device
        )
        weight_tensor = torch.tensor(
            weight_list,
            dtype=self.pos_embed.weight.dtype,
            device=self.pos_embed.weight.device,
        )
        # Weighted corner embeddings, summed over the four corners.
        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

        # Back to one chunk of h * w rows per image.
        patch_pos_embeds = patch_pos_embeds.split(
            [h * w for h, w in zip(grid_hs, grid_ws)]
        )

        patch_pos_embeds_permute = []
        merge_size = self.config.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            pos_embed = pos_embed.repeat(t, 1)
            # Reorder row-major (h, w) into block-major order: tokens of one
            # merge_size x merge_size block become contiguous.
            pos_embed = (
                pos_embed.view(
                    t, h // merge_size, merge_size, w // merge_size, merge_size, -1
                )
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                The final hidden states of the model.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            A tuple `(hidden_states, deepstack_feature_lists)`: the merged output of the
            last block, and the list of merged features tapped at the layers listed in
            `deepstack_visual_indexes`. (The `-> torch.Tensor` annotation understates this.)
        """
        hidden_states = self.patch_embed(hidden_states)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        # Duplicate the frequencies so cos/sin span the full rotary width.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # One segment of h * w tokens per frame; the cumulative sum gives the
        # segment boundaries consumed by the attention layers.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Prepend 0 so boundaries read [0, len_0, len_0 + len_1, ...].
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            # Tap this layer's output through its dedicated merger when configured.
            if layer_num in self.deepstack_visual_indexes:
                deepstack_feature = self.deepstack_merger_list[
                    self.deepstack_visual_indexes.index(layer_num)
                ](hidden_states)
                deepstack_feature_lists.append(deepstack_feature)

        hidden_states = self.merger(hidden_states)

        return hidden_states, deepstack_feature_lists


@auto_docstring(
    custom_intro=(
        "Text part of Qwen3VL, "
        "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
+ ) +) +class Qwen3VLTextModel(Qwen3VLPreTrainedModel): + config: Qwen3VLTextConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer"] + + def __init__(self, config: Qwen3VLTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen3VLTextDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + return_mid_results_layers: Optional[list] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + cache_position (`torch.LongTensor` of shape `(seqlen,)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Used to index into the + key/value cache for incremental decoding. + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). 
+ The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). + return_mid_results_layers (`list[int]`, *optional*): + Indices of decoder layers whose hidden states should be collected and returned as + intermediate results on the output's `mid_results` attribute. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # the hard coded `3` is for temporal, height and width. 
+ if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand( + 3, inputs_embeds.shape[0], -1 + ) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + mid_results = [] if return_mid_results_layers else None + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # === Memory profiling for decoder loop (gated by DEBUG_MEM=1) === + import os as _os + + _mem_debug = _os.environ.get("DEBUG_MEM", "0") == "1" + _gc_count = 0 + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + if self.gradient_checkpointing and torch.is_grad_enabled(): + # Use HuggingFace's _gradient_checkpointing_func which already has + # use_reentrant=False baked in from gradient_checkpointing_enable(). 
+ layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + _gc_count += 1 + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # Per-layer memory logging (gated by DEBUG_MEM=1) + if _mem_debug and (layer_idx % 4 == 0 or layer_idx == len(self.layers) - 1): + _a = torch.cuda.memory_allocated() / 1e9 + _rank = int(_os.environ.get("RANK", 0)) + print( + f"[MEM][rank{_rank}][decoder] layer {layer_idx:2d}/{len(self.layers)}: alloc={_a:.2f}GB", + flush=True, + ) + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range( + len(deepstack_visual_embeds) + ): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + if ( + return_mid_results_layers is not None + and layer_idx in return_mid_results_layers + ): + mid_results.append(hidden_states) + + _a = torch.cuda.memory_allocated() / 1e9 + if _mem_debug: + _rank = int(_os.environ.get("RANK", 0)) + print( + f"[MEM][rank{_rank}][decoder] LOOP END: alloc={_a:.2f}GB, GC_used={_gc_count}/{len(self.layers)} layers", + flush=True, + ) + + hidden_states = self.norm(hidden_states) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + output.mid_results = mid_results + return output + + def _deepstack_process( + self, + hidden_states: torch.Tensor, + visual_pos_masks: torch.Tensor, + visual_embeds: torch.Tensor, + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + 
visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + +class BottleneckPatchEmbed(nn.Module): + def __init__( + self, config, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True + ): + super().__init__() + self.proj1 = nn.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False) + self.proj2 = nn.Linear(pca_dim, embed_dim, bias=bias) + self.initialize_weights() + + def initialize_weights(self): + w1 = self.proj1.weight.data + nn.init.xavier_uniform_(w1.view([w1.shape[0], -1])) + w2 = self.proj2.weight.data + nn.init.xavier_uniform_(w2.view([w2.shape[0], -1])) + nn.init.constant_(self.proj2.bias, 0) + + def forward(self, x): + x = self.proj2(self.proj1(x)) + return x + + +class FinalLayer(nn.Module): + def __init__(self, config, hidden_size, patch_size, out_channels): + super().__init__() + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, adaln_input=None): + x = self.linear(x) + return x + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, config, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.normal_(self.mlp[2].weight, std=0.02) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. 
+ :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t * 1000, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + return t_emb + + +@auto_docstring +class Qwen3VLModel(Qwen3VLPreTrainedModel): + base_model_prefix = "" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3VLVisionModel._from_config(config.vision_config) + self.language_model = Qwen3VLTextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + + self.patch_size = 32 + self.in_channels = 3 + hidden_size = config.text_config.hidden_size + bottleneck_dim = hidden_size // 4 + + self.t_embedder1 = TimestepEmbedder(self.config, hidden_size) + self.x_embedder = BottleneckPatchEmbed( + self.config, + patch_size=self.patch_size, + in_chans=self.in_channels, + pca_dim=bottleneck_dim, + embed_dim=hidden_size, + bias=True, + ) + + # self.t_embedder2 = TimestepEmbedder(self.config, hidden_size) + self.t_embedder2 = None + self.final_layer2 = FinalLayer( + self.config, + hidden_size=hidden_size, + 
patch_size=self.patch_size, + out_channels=self.in_channels, + ) + self.tms_token_id = 151673 + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to seperate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave( + video_grid_thw, video_grid_thw[:, 0], dim=0 + ) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere( + input_ids == 
vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * 
llm_grid_w + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = ( + position_ids.unsqueeze(0) + .expand(3, -1, -1) + .to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: Optional[torch.LongTensor] = None, + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds, deepstack_image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ) + split_sizes = ( + image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2 + ).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor( + self.config.image_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor( + self.config.video_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = ( + special_image_mask.unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + if ( + image_features is not None + and inputs_embeds[special_image_mask].numel() != image_features.numel() + ): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = ( + special_video_mask.unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + if ( + video_features is not None + and inputs_embeds[special_video_mask].numel() != video_features.numel() + ): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + def _run_decoder_flash( + self, inputs_embeds, position_ids, token_types, return_mid_results_layers=None + ): + """Run decoder layers with two-pass attention. + + Replicates the Megatron attention pattern: + 1. Causal attention on AR tokens only (text) + 2. Full (bidirectional) attention on ALL tokens + 3. Replace AR positions with causal result (index_copy) + + This ensures AR tokens only attend causally to other AR tokens, + while gen tokens attend bidirectionally to everything. 
+ + Uses the transformers attention dispatch (ALL_ATTENTION_FUNCTIONS), + so any backend works (sdpa by default, flash_attention_2 if + activated via config._attn_implementation). + + Args: + inputs_embeds: [batch, total_seq_len, hidden] + position_ids: [3, batch, total_seq_len] - 3D RoPE positions + token_types: [batch, total_seq_len] - 0=AR, 1=gen + """ + text_model = self.language_model + + # Compute rotary position embeddings + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + elif position_ids.ndim == 3 and position_ids.shape[0] == 4: + position_ids = position_ids[1:] # drop text_position_ids dim + position_embeddings = text_model.rotary_emb(inputs_embeds, position_ids) + cos, sin = position_embeddings + + # Precompute AR token indices (same layout for all batch items) + is_gen = token_types[0].bool() # [seq_len] + idx_ar = torch.nonzero(~is_gen, as_tuple=False).squeeze(-1) # [n_ar] + + hidden_states = inputs_embeds + mid_results = [] if return_mid_results_layers else None + + use_gc = text_model.gradient_checkpointing and torch.is_grad_enabled() + + def _two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar): + """Two-pass attention layer forward compatible with FSDP2. + + Calls decoder_layer(...) through its __call__ to trigger FSDP hooks + (which swap DTensor parameters to plain tensors), with self_attn.forward + temporarily replaced by a custom two-pass attention implementation + that goes through the transformers attention dispatch. 
+ """ + original_attn_forward = decoder_layer.self_attn.forward + + def _custom_two_pass_attn( + hidden_states, position_embeddings, attention_mask=None, **kwargs + ): + attn = decoder_layer.self_attn + input_shape = hidden_states.shape[:-1] + head_dim = attn.head_dim + hidden_shape = (*input_shape, -1, head_dim) + + # Q, K, V projections in [B, H, S, D] + q = attn.q_norm( + attn.q_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + k = attn.k_norm( + attn.k_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + v = attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + # Apply rotary position embedding + cos_pe, sin_pe = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos_pe, sin_pe) + + scaling = head_dim**-0.5 + + # Attention dispatch — sdpa/eager by default, flash_attention_2 + # if activated via config._attn_implementation. + attention_interface: Callable = eager_attention_forward + if attn.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + attn.config._attn_implementation + ] + + # --- Two-pass attention --- + # Pass 1: causal on AR tokens only (slice on seq dim) + q_ar = q[:, :, idx_ar].contiguous() + k_ar = k[:, :, idx_ar].contiguous() + v_ar = v[:, :, idx_ar].contiguous() + out_ar, _ = attention_interface( + attn, + q_ar, + k_ar, + v_ar, + attention_mask=None, + dropout=0.0, + scaling=scaling, + is_causal=True, + ) # [B, n_ar, H, D] + + # Pass 2: full (bidirectional) on all tokens + out_full, _ = attention_interface( + attn, + q, + k, + v, + attention_mask=None, + dropout=0.0, + scaling=scaling, + is_causal=False, + ) # [B, S, H, D] + + # Replace AR positions with causal result + out_full = out_full.clone() + out_full[:, idx_ar] = out_ar + + # Output projection + attn_output = out_full.reshape(*input_shape, -1).contiguous() + attn_output = attn.o_proj(attn_output) + return attn_output, None + + # Temporarily disable gradient checkpointing on the decoder layer + # to avoid nested 
checkpointing (the outer loop handles GC). + _saved_gc = decoder_layer.gradient_checkpointing + decoder_layer.gradient_checkpointing = False + decoder_layer.self_attn.forward = _custom_two_pass_attn + try: + hidden_states = decoder_layer( + hidden_states, + position_embeddings=(cos, sin), + ) + finally: + decoder_layer.self_attn.forward = original_attn_forward + decoder_layer.gradient_checkpointing = _saved_gc + + return hidden_states + + for layer_idx, decoder_layer in enumerate(text_model.layers): + if use_gc: + hidden_states = torch.utils.checkpoint.checkpoint( + _two_pass_layer_forward, + hidden_states, + decoder_layer, + cos, + sin, + idx_ar, + use_reentrant=False, + ) + else: + hidden_states = _two_pass_layer_forward( + hidden_states, + decoder_layer, + cos, + sin, + idx_ar, + ) + + if ( + return_mid_results_layers is not None + and layer_idx in return_mid_results_layers + ): + mid_results.append(hidden_states) + + # Final layer norm + hidden_states = text_model.norm(hidden_states) + return hidden_states, mid_results + + def _forward_generation( + self, + input_ids, + position_ids, + vinputs, + timestep, + token_types, + attention_mask=None, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + use_flash_attn=False, + return_mid_results_layers=None, + **kwargs, + ): + """Forward pass for image generation (denoising step). 
+ + Args: + input_ids: [batch, txt_seq_len] - text token IDs (without image tokens) + position_ids: [3, batch, total_seq_len] - 3D RoPE positions covering text+image + vinputs: [batch, img_tokens, patch_dim] - patchified noise input + timestep: [batch] - timestep values (scalar per sample) + token_types: [batch, total_seq_len] or [total_seq_len, 1] - 0=AR, >0=gen + attention_mask: ignored (created internally for non-flash path) + pixel_values: optional image pixel values for conditioned generation + pixel_values_videos: optional video pixel values + image_grid_thw: optional image grid info + video_grid_thw: optional video grid info + use_flash_attn: if True, use flash attention with two-pass approach + + Returns: + Qwen3VLModelOutputWithPast with x_pred field set. + """ + # 1. Get text token embeddings + inputs_embeds = self.get_input_embeddings()( + input_ids + ) # [batch, txt_seq_len, hidden] + + # 2. Process image/video embeddings if present (for image-conditioned generation) + if pixel_values is not None: + image_embeds, _ = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, _ = self.get_video_features( + pixel_values_videos, video_grid_thw + ) + video_embeds = torch.cat(video_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # 3. 
Embed timestep and replace tms_token positions + if isinstance(timestep, list): + timestep = torch.cat(timestep, dim=0) + timestep = timestep.to(inputs_embeds.device) + t_emb = self.t_embedder1(timestep) # [batch, hidden] + + tms_mask = input_ids == self.tms_token_id # [batch, txt_seq_len] + tms_mask_3d = tms_mask.unsqueeze(-1).expand_as(inputs_embeds) + t_emb_expanded = t_emb.unsqueeze(1).expand_as(inputs_embeds) + inputs_embeds = torch.where(tms_mask_3d, t_emb_expanded, inputs_embeds) + + # 4. Embed vinputs and append to sequence + if isinstance(vinputs, list): + vinputs = torch.cat(vinputs, dim=0) + vinputs = vinputs.to(inputs_embeds.device) + vinputs_embedded = self.x_embedder(vinputs).to( + inputs_embeds.dtype + ) # [batch, img_tokens, hidden] + inputs_embeds = torch.cat( + [inputs_embeds, vinputs_embedded], dim=1 + ) # [batch, total_seq_len, hidden] + + batch_size, total_seq_len, _ = inputs_embeds.shape + + # 5. Parse token_types to [batch, total_seq_len] + if isinstance(token_types, list): + token_types = torch.cat(token_types, dim=0) + token_types = token_types.to(inputs_embeds.device) + if token_types.dim() == 1: + token_types = token_types.unsqueeze(0) + elif ( + token_types.dim() == 2 + and token_types.shape[-1] == 1 + and token_types.shape[0] == total_seq_len + ): + # [total_seq_len, 1] -> [1, total_seq_len] + token_types = token_types.squeeze(-1).unsqueeze(0) + if token_types.shape[0] == 1 and batch_size > 1: + token_types = token_types.expand(batch_size, -1) + + # 6. 
Forward through decoder + mid_results = None + import os as _os2 + + _mem_debug2 = _os2.environ.get("DEBUG_MEM", "0") == "1" + + if _mem_debug2: + _rank2 = int(_os2.environ.get("RANK", 0)) + _a = torch.cuda.memory_allocated() / 1e9 + print( + f"[MEM][rank{_rank2}][_forward_gen] before decoder: alloc={_a:.2f}GB, " + f"total_seq_len={total_seq_len}, batch={batch_size}, flash={use_flash_attn}", + flush=True, + ) + + if use_flash_attn: + # Flash attention: two-pass approach (causal on AR + full on all → index_copy) + hidden_states, mid_results = self._run_decoder_flash( + inputs_embeds, + position_ids, + token_types, + return_mid_results_layers=return_mid_results_layers, + ) + else: + # Standard path: 4D attention mask (causal for AR, full for gen tokens) + dtype = inputs_embeds.dtype + min_val = torch.finfo(dtype).min + attn_masks = [] + for b in range(batch_size): + causal = torch.full( + (total_seq_len, total_seq_len), + min_val, + device=inputs_embeds.device, + dtype=dtype, + ) + causal = torch.triu( + causal, diagonal=1 + ) # lower tri + diag = 0 (allowed) + gen_positions = token_types[b].bool() # [total_seq_len] + causal[gen_positions, :] = 0 # gen tokens attend to everything + attn_masks.append(causal) + attention_mask_4d = torch.stack(attn_masks, dim=0).unsqueeze( + 1 + ) # [batch, 1, seq, seq] + + if _mem_debug2: + _rank2 = int(_os2.environ.get("RANK", 0)) + _mask_gb = ( + attention_mask_4d.element_size() * attention_mask_4d.numel() / 1e9 + ) + _a = torch.cuda.memory_allocated() / 1e9 + print( + f"[MEM][rank{_rank2}][_forward_gen] attn_mask_4d: shape={list(attention_mask_4d.shape)}, " + f"size={_mask_gb:.3f}GB, alloc={_a:.2f}GB", + flush=True, + ) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask_4d, + inputs_embeds=inputs_embeds, + use_cache=False, + return_mid_results_layers=return_mid_results_layers, + ) + hidden_states = outputs.last_hidden_state + if hasattr(outputs, "mid_results"): + 
mid_results = outputs.mid_results + + # 7. Apply final layer to get pixel predictions + x_pred = self.final_layer2(hidden_states) # [batch, total_seq_len, out_dim] + + return Qwen3VLModelOutputWithPast( + last_hidden_state=hidden_states, + x_pred=x_pred, + mid_results=mid_results, + ) + + @auto_docstring + @merge_with_config_defaults + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + vinputs: Optional[torch.Tensor] = None, + timestep: Optional[torch.Tensor] = None, + token_types: Optional[torch.Tensor] = None, + use_flash_attn: bool = False, + return_mid_results_layers: Optional[list] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + cache_position (`torch.LongTensor` of shape `(seqlen,)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Used to index into the + key/value cache for incremental decoding. + vinputs (`torch.Tensor`, *optional*): + Visual inputs for the generation pathway. When provided, the forward call is dispatched to + `_forward_generation` instead of the standard understanding path. 
+ timestep (`torch.Tensor`, *optional*): + Diffusion timestep tensor passed through to the generation forward when `vinputs` is provided. + token_types (`torch.Tensor`, *optional*): + Per-token type identifiers used by the generation forward to distinguish text vs. visual tokens. + use_flash_attn (`bool`, *optional*, defaults to `False`): + Whether to use a flash-attention kernel in the generation forward path. + return_mid_results_layers (`list[int]`, *optional*): + Indices of decoder layers whose hidden states should be collected and returned as + intermediate results on the output's `mid_results` attribute. + """ + # Dispatch to generation forward if vinputs is provided + if vinputs is not None: + return self._forward_generation( + input_ids=input_ids, + position_ids=position_ids, + vinputs=vinputs, + timestep=timestep, + token_types=token_types, + attention_mask=attention_mask, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_flash_attn=use_flash_attn, + return_mid_results_layers=return_mid_results_layers, + **kwargs, + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + video_mask = None + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features( + pixel_values, image_grid_thw + ) + image_embeds = torch.cat(image_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw + ) + video_embeds 
= torch.cat(video_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip( + deepstack_image_embeds, deepstack_video_embeds + ): + embed_joint = img_embed.new_zeros( + visual_pos_masks.sum(), img_embed.shape[-1] + ).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if position_ids is None: + attention_mask_tensor = ( + attention_mask + if not isinstance(attention_mask, dict) + else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal( + attention_mask_tensor[:, 0], dim1=1, dim2=2 + ) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = ( + attention_mask_tensor + / torch.finfo(attention_mask_tensor.dtype).min + ) + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in 
the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if ( + prefill_compiled_stage or prefill_noncompiled_stage + ) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + **kwargs, + ) + + return Qwen3VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + 
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Qwen3VL causal language model (or autoregressive) outputs.
    """
)
class Qwen3VLCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    x_pred (`torch.FloatTensor`, *optional*):
        Output of the generation path. `Qwen3VLForConditionalGeneration.forward` fills this field (and leaves
        `loss`/`logits` unset) only when it is called with `vinputs`. Shape is not established here -- TODO confirm
        against the generation forward.
    mid_results (`list`, *optional*):
        Intermediate results forwarded from the generation path when `return_mid_results_layers` is given;
        `None` when the underlying model output carries no `mid_results` attribute.
    """

    # Language-modeling path fields.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
    # Generation-path fields (populated only when `vinputs` is passed to forward).
    x_pred: Optional[torch.FloatTensor] = None
    mid_results: Optional[list] = None
@merge_with_config_defaults
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    vinputs: Optional[torch.Tensor] = None,
    timestep: Optional[torch.Tensor] = None,
    token_types: Optional[torch.Tensor] = None,
    use_flash_attn: bool = False,
    return_mid_results_layers: Optional[list] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Targets for the language modeling loss. Values are either token indices in `[0, ..., config.vocab_size]`
        or `-100`; positions set to `-100` are ignored (masked) and contribute nothing to the loss.
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    vinputs (`torch.Tensor`, *optional*):
        When given, the call takes the generation path: the wrapped model is run with `vinputs`, `timestep`,
        `token_types`, `use_flash_attn` and `return_mid_results_layers`, and only `x_pred` / `mid_results`
        are returned (no logits, no loss).
    logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to `0`):
        An `int` keeps the last `logits_to_keep` positions (`0` keeps every position); a tensor is used
        directly as the index along the sequence dimension.

    Example:
        TODO: Add example
    """
    model_outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        vinputs=vinputs,
        timestep=timestep,
        token_types=token_types,
        use_flash_attn=use_flash_attn,
        return_mid_results_layers=return_mid_results_layers,
        **kwargs,
    )

    # Generation path: hand back the prediction as-is, skipping the LM head entirely.
    if vinputs is not None:
        return Qwen3VLCausalLMOutputWithPast(
            x_pred=model_outputs.x_pred,
            mid_results=getattr(model_outputs, "mid_results", None),
        )

    last_hidden = model_outputs[0]

    # Project only the requested positions so unused logits are never materialised.
    if isinstance(logits_to_keep, int):
        keep = slice(-logits_to_keep, None)
    else:
        keep = logits_to_keep
    logits = self.lm_head(last_hidden[:, keep, :])

    if labels is None:
        loss = None
    else:
        loss = self.loss_function(
            logits=logits,
            labels=labels,
            vocab_size=self.config.text_config.vocab_size,
        )

    return Qwen3VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=model_outputs.past_key_values,
        rope_deltas=model_outputs.rope_deltas,
    )
cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) + + # Qwen3VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
+ + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor( + vision_start_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor( + image_token_id, dtype=torch.long, device=inputs_embeds.device + ) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor( + video_token_id, dtype=torch.long, device=inputs_embeds.device + ) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + def 
_expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat( + [sample.repeat(*repeat_args) for sample in samples], dim=0 + ) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], + lengths=list(video_nums), + repeat_times=expand_size, + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and 
@maybe_allow_in_graph
class HiDreamAttention(Attention):
    """
    Attention module holding the projections used by the HiDream attention processor.

    It always owns one set of q/k/v/out projections plus RMS norms for q and k. Unless
    `single=True`, it also owns a second, `_t`-suffixed set; the processor in
    `attention_processor.py` applies that second set to the text tokens and the first set
    to the image tokens.

    Args:
        query_dim: Input feature size of the query projection.
        heads: Number of attention heads (ignored for `self.heads` when `out_dim` is given).
        dim_head: Feature size per head.
        upcast_attention / upcast_softmax: Stored on the instance; not read in this class.
        scale_qk: When true, `self.scale` is `dim_head ** -0.5`, otherwise `1.0`.
        eps: Epsilon passed to the RMS norms.
        processor: Attention processor forwarded to `set_processor`.
        out_dim: Optional override for the inner and output feature size.
        single: When true, the text-side (`_t`) projections are not created.
    """

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        scale_qk: bool = True,
        eps: float = 1e-5,
        processor = None,
        out_dim: int = None,
        single: bool = False
    ):
        # Deliberately start the super() lookup *after* `Attention` in the MRO, so that
        # `Attention.__init__` is skipped and only its base initialiser runs.
        super(Attention, self).__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.out_dim = out_dim if out_dim is not None else query_dim

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        # With an explicit out_dim the head count is derived from it instead of `heads`.
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.single = single

        linear_cls = nn.Linear
        self.linear_cls = linear_cls
        # Image-side projections. Note to_k / to_v take `inner_dim` inputs, not `query_dim`.
        self.to_q = linear_cls(query_dim, self.inner_dim)
        self.to_k = linear_cls(self.inner_dim, self.inner_dim)
        self.to_v = linear_cls(self.inner_dim, self.inner_dim)
        self.to_out = linear_cls(self.inner_dim, self.out_dim)
        self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
        self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)

        if not single:
            # Text-side twin of every projection/norm above.
            self.to_q_t = linear_cls(query_dim, self.inner_dim)
            self.to_k_t = linear_cls(self.inner_dim, self.inner_dim)
            self.to_v_t = linear_cls(self.inner_dim, self.inner_dim)
            self.to_out_t = linear_cls(self.inner_dim, self.out_dim)
            self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
            self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)

        self.set_processor(processor)
        # Re-initialise every nn.Linear created above (see _init_weights).
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform the weight of each `nn.Linear` and zero its bias; other modules are untouched."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        norm_image_tokens: torch.FloatTensor,
        image_tokens_masks: torch.FloatTensor = None,
        norm_text_tokens: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.Tensor:
        """
        Delegate to the configured processor, renaming the arguments to the processor's keywords.

        Returns whatever the processor returns. NOTE(review): the HiDream processor in this
        package returns an `(image, text)` tuple when `single` is false, which the
        `torch.Tensor` annotation does not reflect.
        """
        return self.processor(
            self,
            image_tokens = norm_image_tokens,
            image_tokens_masks = image_tokens_masks,
            text_tokens = norm_text_tokens,
            rope = rope,
        )
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors with precomputed rotary matrices.

    The last dimension of `xq` / `xk` is read as consecutive (even, odd) pairs; each pair is
    combined with the two columns of the matching 2x2 matrix stored in the trailing
    dimensions of `freqs_cis`. The arithmetic runs in float32 and each result is cast back
    to the dtype of its input.

    Returns:
        `(xq_rotated, xk_rotated)` with the same shapes and dtypes as `xq` and `xk`.
    """
    first_col = freqs_cis[..., 0]
    second_col = freqs_cis[..., 1]

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        # [..., D] -> [..., D/2, 1, 2]: one (even, odd) pair per rotary frequency.
        pairs = tensor.float().reshape(*tensor.shape[:-1], -1, 1, 2)
        rotated = first_col * pairs[..., 0] + second_col * pairs[..., 1]
        return rotated.reshape(*tensor.shape).type_as(tensor)

    return _rotate(xq), _rotate(xk)
flash_attn_func(query, key, value, causal=False, deterministic=False)[0] + else: + hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False) + else: + # Use torch's scaled dot-product attention as fallback + # Reshape for torch.nn.functional.scaled_dot_product_attention which expects [batch, heads, seq_len, head_dim] + query = query.transpose(1, 2) # [batch, heads, seq_len, head_dim] + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, key, value, + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + + # Restore original shape + hidden_states = hidden_states.transpose(1, 2) # [batch, seq_len, heads, head_dim] + + hidden_states = hidden_states.flatten(-2) + hidden_states = hidden_states.to(query.dtype) + return hidden_states + +class HiDreamAttnProcessor_flashattn: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __call__( + self, + attn: HiDreamAttention, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + dtype = image_tokens.dtype + batch_size = image_tokens.shape[0] + + query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype) + key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype) + value_i = attn.to_v(image_tokens) + + inner_dim = key_i.shape[-1] + head_dim = inner_dim // attn.heads + + query_i = query_i.view(batch_size, -1, attn.heads, head_dim) + key_i = key_i.view(batch_size, -1, attn.heads, head_dim) + value_i = value_i.view(batch_size, -1, attn.heads, head_dim) + if image_tokens_masks is not None: + key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1) + + if not attn.single: + query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype) + key_t = 
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """
    Build per-position 2x2 rotation matrices for rotary position embeddings.

    Args:
        pos: Position tensor of shape `(batch_size, seq_length)`.
        dim: Number of channels covered by this axis; must be even. `dim // 2` rotation
            frequencies are produced.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape `(batch_size, seq_length, dim // 2, 2, 2)` whose trailing
        2x2 blocks are `[[cos, -sin], [sin, cos]]` of `pos * theta ** (-2i / dim)`.

    Raises:
        AssertionError: if `dim` is odd.
        ValueError: if `pos` is not 2-dimensional (from the shape unpacking).
    """
    assert dim % 2 == 0, "The dimension must be even."

    # Frequencies are computed in float64 for accuracy, but MPS and NPU devices have no
    # float64 support, so fall back to float32 there instead of failing.
    freq_dtype = torch.float32 if pos.device.type in ("mps", "npu") else torch.float64
    scale = torch.arange(0, dim, 2, dtype=freq_dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    # Unpacking doubles as the check that `pos` is exactly (batch, seq).
    batch_size, _ = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)

    # Flattened rotation matrix per (position, frequency): [cos, -sin, sin, cos].
    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()
class OutEmbed(nn.Module):
    """
    Output head: adaptive-LayerNorm modulation followed by a linear projection to patch pixels.

    Every `nn.Linear` inside the head starts with zero weight and zero bias, so a freshly
    constructed head outputs zeros.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        patch_features = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_features, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Zero the weight and bias of every `nn.Linear`; leave other modules alone."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.zeros_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x, adaln_input):
        """
        Modulate the normalised tokens with a shift/scale derived from `adaln_input`, then project.

        The modulation vector is split in half along dim 1 (first half shift, second half
        scale) and broadcast over the token dimension of `x`.
        """
        modulation = self.adaLN_modulation(adaln_input)
        shift, scale = modulation.chunk(2, dim=1)
        normed = self.norm_final(x)
        modulated = normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.linear(modulated)
# Module-level registry of load-balancing records, one per recorded MoE gate call.
# Each record is indexed as (aux_loss, Pi, fi, alpha): index 1 and 2 are stacked by
# batched_load_balancing_loss(), and the last element is read as the loss weight.
_LOAD_BALANCING_LOSS = []


def save_load_balancing_loss(loss):
    """Append one load-balancing record to the module-level registry."""
    _LOAD_BALANCING_LOSS.append(loss)


def clear_load_balancing_loss():
    """Drop every recorded entry (the list object itself is kept)."""
    _LOAD_BALANCING_LOSS.clear()


def get_load_balancing_loss():
    """Return the live registry list (not a copy)."""
    return _LOAD_BALANCING_LOSS


def batched_load_balancing_loss():
    """
    Combine all recorded entries into a single auxiliary load-balancing loss.

    `fi` is averaged across ranks when a `torch.distributed` process group is initialised.
    In a single-process run there is nothing to gather, so the local `fi` is used as-is
    (previously this path failed because `all_gather` needs a process group).

    Returns:
        Scalar tensor `mean(sum(Pi * fi, -1)) * alpha`, with `alpha` taken from the first record.

    Raises:
        IndexError: if no entry has been recorded.
    """
    aux_losses_arr = get_load_balancing_loss()
    alpha = aux_losses_arr[0][-1]
    Pi = torch.stack([ent[1] for ent in aux_losses_arr], dim=0)
    fi = torch.stack([ent[2] for ent in aux_losses_arr], dim=0)

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # Imported here so the collective is only required when it is actually used.
        from torch.distributed.nn.functional import all_gather

        fi_list = all_gather(fi)
        fi = torch.stack(fi_list, 0).mean(0)

    aux_loss = (Pi * fi).sum(-1).mean() * alpha
    return aux_loss
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
    """
    Mixture-of-experts feed-forward layer: routed SwiGLU experts plus one always-on shared expert.

    Args:
        dim: Token feature size (input and output of every expert).
        hidden_dim: Hidden size handed to each routed expert; the shared expert gets `hidden_dim // 2`.
        num_routed_experts: Number of routed experts.
        num_activated_experts: Number of experts selected per token by the gate.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_routed_experts: int,
        num_activated_experts: int,
    ):
        super().__init__()
        self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2)
        self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)])
        self.gate = MoEGate(
            embed_dim = dim,
            num_routed_experts = num_routed_experts,
            num_activated_experts = num_activated_experts
        )
        self.num_activated_experts = num_activated_experts

    def forward(self, x):
        """Route every token through its selected experts and add the shared expert's output."""
        wtype = x.dtype
        identity = x
        orig_shape = x.shape
        # aux_loss is returned by the gate but not used below.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        # Flatten to one row per token; the flat index list then holds one slot per
        # (token, selected expert) pair.
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        # The same routine is used for training and inference (the original training
        # branch is kept below for reference only).
        y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        # this was in original and memory leaks, not needed
        # if self.training:
        #     x = x.repeat_interleave(self.num_activated_experts, dim=0)
        #     y = torch.empty_like(x, dtype=wtype)
        #     for i, expert in enumerate(self.experts):
        #         y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
        #     y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
        #     y = y.view(*orig_shape).to(dtype=wtype)
        #     #y = AddAuxiliaryLoss.apply(y, aux_loss)
        # else:
        #     y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        y = y + self.shared_experts(identity)
        return y

    # @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """
        Run each expert once on the tokens routed to it and sum the weighted outputs per token.

        Args:
            x: Flattened tokens, one row per token.
            flat_expert_indices: Expert id for every (token, slot) pair, flattened.
            flat_expert_weights: Routing weight for every (token, slot) pair, one column.

        Returns:
            Tensor shaped like `x` (its dtype follows the expert outputs) holding the
            weighted sum of the selected experts' outputs for every token.
        """
        expert_cache = torch.zeros_like(x)
        # Sorting the slots by expert id makes every expert's slots contiguous in `idxs`.
        idxs = flat_expert_indices.argsort()
        # Cumulative slot counts = end offset of each expert's run inside `idxs`.
        # Note: .cpu().numpy() copies the counts to the host.
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # Integer division by the number of slots per token maps a flat slot index
        # back to the row of its token.
        token_idxs = idxs // self.num_activated_experts
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i-1]
            if start_idx == end_idx:
                # No token was routed to expert i.
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)
            # Scale in place by the routing weight of each slot.
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # for fp16 and other dtype
            expert_cache = expert_cache.to(expert_out.dtype)
            # Accumulate each output row into the row of its source token.
            expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
        return expert_cache
# File added by this patch:
# ai-toolkit/extensions_built_in/diffusion_models/hidream/src/models/transformers/transformer_hidream_image.py
from typing import Any, Callable, Dict, Optional, Tuple, List

import torch
import torch.nn as nn
import einops
from einops import repeat

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from ..embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed
from ..attention import HiDreamAttention, FeedForwardSwiGLU
from ..attention_processor import HiDreamAttnProcessor_flashattn
from ..moe import MOEFeedForwardSwiGLU

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class TextProjection(nn.Module):
    """Bias-free linear map from a text-encoder width to the transformer width."""

    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)

    def forward(self, caption):
        return self.linear(caption)


class BlockType:
    """Integer tags selecting which block implementation `HiDreamImageBlock` wraps."""

    TransformerBlock = 1
    SingleTransformerBlock = 2


@maybe_allow_in_graph
class HiDreamImageSingleTransformerBlock(nn.Module):
    """Single-stream block: adaLN-modulated attention followed by a feed-forward."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        # Produces 6 modulation vectors (shift/scale/gate for attention and MLP).
        # Zero-initialised so every gate starts at 0.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True)
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = True
        )

        # 3. Feed-forward (MoE when routed experts are requested, dense otherwise)
        self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """Apply the block to `image_tokens`; `text_tokens` is accepted but unused here."""
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        attn_output_i = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            rope = rope,
        )
        image_tokens = gate_msa_i * attn_output_i + image_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
        image_tokens = ff_output_i + image_tokens
        return image_tokens


@maybe_allow_in_graph
class HiDreamImageTransformerBlock(nn.Module):
    """Double-stream block: joint attention, then separate image and text feed-forwards."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        # 12 modulation vectors: 6 for the image stream, 6 for the text stream.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 12 * dim, bias=True)
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        self.norm1_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = False
        )

        # 3. Feed-forward (image stream may be MoE; text stream is always dense)
        self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
        self.norm3_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """Return the updated ``(image_tokens, text_tokens)`` pair."""
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
        shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t

        attn_output_i, attn_output_t = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            norm_text_tokens,
            rope = rope,
        )

        image_tokens = gate_msa_i * attn_output_i + image_tokens
        text_tokens = gate_msa_t * attn_output_t + text_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t

        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
        image_tokens = ff_output_i + image_tokens
        text_tokens = ff_output_t + text_tokens
        return image_tokens, text_tokens


@maybe_allow_in_graph
class HiDreamImageBlock(nn.Module):
    """Uniform wrapper that holds either block flavour under the attribute `block`."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        block_type: BlockType = BlockType.TransformerBlock,
    ):
        super().__init__()
        block_classes = {
            BlockType.TransformerBlock: HiDreamImageTransformerBlock,
            BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
        }
        self.block = block_classes[block_type](
            dim,
            num_attention_heads,
            attention_head_dim,
            num_routed_experts,
            num_activated_experts
        )

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        # Arguments are forwarded positionally, so the order must match the
        # wrapped blocks' forward signatures.
        return self.block(
            image_tokens,
            image_tokens_masks,
            text_tokens,
            adaln_input,
            rope,
        )


class HiDreamImageTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
    """HiDream image transformer: double-stream blocks followed by single-stream blocks."""

    _supports_gradient_checkpointing = True
    _no_split_modules = ["HiDreamImageBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: Optional[int] = None,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 16,
        num_single_layers: int = 32,
        attention_head_dim: int = 128,
        num_attention_heads: int = 20,
        caption_channels: List[int] = None,
        text_emb_dim: int = 2048,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        axes_dims_rope: Tuple[int, int] = (32, 32),
        max_resolution: Tuple[int, int] = (128, 128),
        llama_layers: List[int] = None,
    ):
        super().__init__()
        # Both values are dereferenced further down. Fail before allocating
        # every block, with the same exception type the late failure raised.
        if patch_size is None:
            raise TypeError("`patch_size` must be an int, got None.")
        if caption_channels is None:
            raise TypeError("`caption_channels` must be a sequence of two ints, got None.")

        self.out_channels = out_channels or in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.llama_layers = llama_layers

        self.t_embedder = TimestepEmbed(self.inner_dim)
        self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim)
        self.x_embedder = PatchEmbed(
            patch_size = patch_size,
            in_channels = in_channels,
            out_channels = self.inner_dim,
        )
        self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)

        self.double_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.config.num_attention_heads,
                    attention_head_dim = self.config.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.TransformerBlock
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.single_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.config.num_attention_heads,
                    attention_head_dim = self.config.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.SingleTransformerBlock
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels)

        # One projection per block using caption_channels[1], plus a final one
        # using caption_channels[0]. `forward` applies projection i to the i-th
        # selected llama layer and the last projection to the T5 states.
        caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
        caption_projection = []
        for caption_channel in caption_channels:
            caption_projection.append(TextProjection(in_features = caption_channel, hidden_size = self.inner_dim))
        self.caption_projection = nn.ModuleList(caption_projection)
        self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

        self.gradient_checkpointing = False

    def expand_timesteps(self, timesteps, batch_size, device):
        """Turn a Python scalar or 0-d tensor into a 1-D tensor and expand it to `batch_size`."""
        if not torch.is_tensor(timesteps):
            # mps gets the 32-bit dtypes, every other device the 64-bit ones
            is_mps = device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(batch_size)
        return timesteps

    # the implementation on hidream during train was wrong, just use the inference one.
    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
        """Fold each sample's first ``pH * pW`` tokens back into an image-shaped tensor.

        `is_training` is accepted for signature compatibility and ignored.
        The per-sample results are concatenated on dim 0 and returned as one tensor.
        """
        # Process all images in the batch according to their specific dimensions
        x_arr = []
        for i, img_size in enumerate(img_sizes):
            pH, pW = img_size
            x_arr.append(
                einops.rearrange(
                    x[i, :pH*pW].reshape(1, pH, pW, -1),
                    'B H W (p1 p2 C) -> B C (H p1) (W p2)',
                    p1=self.config.patch_size, p2=self.config.patch_size
                )
            )
        x = torch.cat(x_arr, dim=0)
        return x

    def patchify(self, x, max_seq, img_sizes=None):
        """Convert latents to token sequences.

        With `img_sizes`, the input is treated as already split into patches
        and a ``(B, max_seq)`` validity mask is built. Otherwise `x` must be a
        tensor, it is split here, and the returned mask is ``None``.

        Returns:
            ``(x, x_masks, img_sizes)``.
        """
        pz2 = self.config.patch_size * self.config.patch_size
        if img_sizes is not None:
            if isinstance(x, torch.Tensor):
                B = x.shape[0]
                device = x.device
                dtype = x.dtype
            else:
                B = len(x)
                device = x[0].device
                dtype = x[0].dtype
            # The mask is only needed on this path, so it is allocated here
            # rather than up front and then discarded.
            x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
            for i, img_size in enumerate(img_sizes):
                x_masks[i, 0:img_size[0] * img_size[1]] = 1
            x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
        elif isinstance(x, torch.Tensor):
            B = x.shape[0]
            pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size
            x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size)
            img_sizes = [[pH, pW]] * B
            x_masks = None
        else:
            raise NotImplementedError
        return x, x_masks, img_sizes

    def forward(
        self,
        hidden_states: torch.Tensor,
        timesteps: torch.LongTensor = None,
        encoder_hidden_states: torch.Tensor = None,
        pooled_embeds: torch.Tensor = None,
        img_sizes: Optional[List[Tuple[int, int]]] = None,
        img_ids: Optional[torch.Tensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        """Run the full transformer.

        `encoder_hidden_states` is indexed as a sequence: element 0 is used as
        the T5 states and element -1 is indexed by every entry of
        `self.llama_layers`. When `img_sizes` is given (pre-patchified input),
        `img_ids` must be given too.

        Returns:
            ``(output, image_tokens_masks)`` when `return_dict` is False,
            otherwise a `Transformer2DModelOutput`.
        """
        # Remember whether the caller asked for a scale *before* it is popped,
        # otherwise the non-PEFT warning below can never fire.
        scale_requested = (
            joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None
        )
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_requested:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        # spatial forward
        batch_size = hidden_states.shape[0]
        hidden_states_type = hidden_states.dtype

        # 0. time
        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
        timesteps = self.t_embedder(timesteps, hidden_states_type)
        p_embedder = self.p_embedder(pooled_embeds)
        adaln_input = timesteps + p_embedder

        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
        if image_tokens_masks is None:
            # Raw-latent path: build (row, col) position ids for the patch grid.
            pH, pW = img_sizes[0]
            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
        elif img_ids is None:
            raise ValueError("`img_ids` must be provided when `img_sizes` is given (pre-patchified input).")
        hidden_states = self.x_embedder(hidden_states)

        T5_encoder_hidden_states = encoder_hidden_states[0]
        encoder_hidden_states = encoder_hidden_states[-1]
        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]

        if self.caption_projection is not None:
            new_encoder_hidden_states = []
            for i, enc_hidden_state in enumerate(encoder_hidden_states):
                enc_hidden_state = self.caption_projection[i](enc_hidden_state)
                enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                new_encoder_hidden_states.append(enc_hidden_state)
            encoder_hidden_states = new_encoder_hidden_states
            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
            encoder_hidden_states.append(T5_encoder_hidden_states)

        # Text positions are all-zero ids covering T5 + last llama + one llama layer.
        txt_ids = torch.zeros(
            batch_size,
            encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
            3,
            device=img_ids.device, dtype=img_ids.dtype
        )
        ids = torch.cat((img_ids, txt_ids), dim=1)
        rope = self.pe_embedder(ids)

        # 2. Blocks
        block_id = 0
        initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
        initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
        for block in self.double_stream_blocks:
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
            cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    image_tokens_masks,
                    cur_encoder_hidden_states,
                    adaln_input.clone(),
                    rope.clone(),
                )
            else:
                hidden_states, initial_encoder_hidden_states = block(
                    image_tokens = hidden_states,
                    image_tokens_masks = image_tokens_masks,
                    text_tokens = cur_encoder_hidden_states,
                    adaln_input = adaln_input,
                    rope = rope,
                )
            # Drop the per-block llama tokens again; only the shared text prefix is carried over.
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if image_tokens_masks is not None:
            # NOTE(review): relies on `cur_llama31_encoder_hidden_states` from the loop
            # above, so this path needs at least one double-stream block — confirm.
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
            )
            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)

        for block in self.single_stream_blocks:
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    image_tokens_masks,
                    None,
                    adaln_input.clone(),
                    rope.clone(),
                )
            else:
                hidden_states = block(
                    image_tokens = hidden_states,
                    image_tokens_masks = image_tokens_masks,
                    text_tokens = None,
                    adaln_input = adaln_input,
                    rope = rope,
                )
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
        output = self.final_layer(hidden_states, adaln_input)
        output = self.unpatchify(output, img_sizes, self.training)
        if image_tokens_masks is not None:
            image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output, image_tokens_masks)
        # NOTE(review): the stock diffusers `Transformer2DModelOutput` appears to declare
        # only `sample`; confirm that the installed version accepts `mask=`.
        return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)


# ---------------------------------------------------------------------------
# The patch continues with the next added file; its header and leading imports
# share the last physical line of this span and are preserved below:
# ai-toolkit/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image.py
# (one further `from ... import` statement starts here and continues past this span)
# ---------------------------------------------------------------------------
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import math
import einops
import torch
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
    LlamaForCausalLM,
    PreTrainedTokenizerFast
)

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
)
from diffusers.utils.torch_utils import randn_tensor
diffusers.pipelines.pipeline_utils import DiffusionPipeline +from .pipeline_output import HiDreamImagePipelineOutput +from ...models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel +from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5Tokenizer, + text_encoder_4: LlamaForCausalLM, + tokenizer_4: PreTrainedTokenizerFast, + transformer: HiDreamImageTransformer2DModel, + aggressive_unloading: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + text_encoder_4=text_encoder_4, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + tokenizer_4=tokenizer_4, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + self.tokenizer_4.pad_token = self.tokenizer_4.eos_token + self.aggressive_unloading = aggressive_unloading + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_3.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = 
prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=min(max_sequence_length, 218), + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {218} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def _get_llama3_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_4.dtype + + prompt = [prompt] if isinstance(prompt, str) 
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Union[str, List[str]],
    prompt_3: Union[str, List[str]],
    prompt_4: Union[str, List[str]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    negative_prompt_4: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 128,
    lora_scale: Optional[float] = None,
):
    """Produce positive and (optionally) negative prompt embeddings.

    The positive prompts are always passed through ``self._encode_prompt``. The
    negative prompts are encoded only when ``do_classifier_free_guidance`` is
    truthy and no ``negative_prompt_embeds`` were supplied; otherwise the
    negative values passed in are returned unchanged.

    A missing ``negative_prompt`` becomes ``""``; a missing ``negative_prompt_2``
    / ``_3`` / ``_4`` falls back to ``negative_prompt``. String negatives are
    expanded to a list of ``batch_size`` copies.

    ``lora_scale`` is accepted for interface compatibility and is not read here.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds)``.

    Raises:
        TypeError: ``negative_prompt`` is not the same type as ``prompt``.
        ValueError: ``negative_prompt`` does not have ``batch_size`` entries.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt) if prompt is not None else prompt_embeds[0].shape[0]

    # Arguments that are identical for the positive and the negative pass.
    shared = {
        "device": device,
        "dtype": dtype,
        "num_images_per_prompt": num_images_per_prompt,
        "max_sequence_length": max_sequence_length,
    }

    prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        prompt_4=prompt_4,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        **shared,
    )

    wants_negative = do_classifier_free_guidance and negative_prompt_embeds is None
    if not wants_negative:
        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    def _per_batch(value):
        # A bare string stands for "the same text for every batch entry".
        return batch_size * [value] if isinstance(value, str) else value

    negative_prompt = negative_prompt or ""
    default_negative = negative_prompt  # fallback is taken before list expansion
    negative_prompt_2 = _per_batch(negative_prompt_2 or default_negative)
    negative_prompt_3 = _per_batch(negative_prompt_3 or default_negative)
    negative_prompt_4 = _per_batch(negative_prompt_4 or default_negative)
    negative_prompt = _per_batch(negative_prompt)

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
        prompt=negative_prompt,
        prompt_2=negative_prompt_2,
        prompt_3=negative_prompt_3,
        prompt_4=negative_prompt_4,
        prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=negative_pooled_prompt_embeds,
        **shared,
    )
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
self.text_encoder_2, + prompt = prompt_2, + num_images_per_prompt = num_images_per_prompt, + max_sequence_length = max_sequence_length, + device = device, + dtype = dtype, + ) + + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) + + t5_prompt_embeds = self._get_t5_prompt_embeds( + prompt = prompt_3, + num_images_per_prompt = num_images_per_prompt, + max_sequence_length = max_sequence_length, + device = device, + dtype = dtype + ) + llama3_prompt_embeds = self._get_llama3_prompt_embeds( + prompt = prompt_4, + num_images_per_prompt = num_images_per_prompt, + max_sequence_length = max_sequence_length, + device = device, + dtype = dtype + ) + prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] + + return prompt_embeds, pooled_prompt_embeds + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Return the initial latents for sampling.

    ``height`` and ``width`` are converted to latent resolution by integer
    division by ``self.vae_scale_factor * 2`` and then doubled, so both latent
    dimensions are always even (required by the 2x2 packing).

    When ``latents`` is given it must already have the expected shape; it is
    moved to ``device`` and returned. Otherwise fresh noise is drawn with
    ``randn_tensor`` using ``generator``, ``device`` and ``dtype``.

    Raises:
        ValueError: the supplied ``latents`` do not have the expected shape.
    """
    packing_divisor = self.vae_scale_factor * 2
    latent_height = 2 * (int(height) // packing_divisor)
    latent_width = 2 * (int(width) // packing_divisor)
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if latents is not None:
        # Caller-provided latents: validate the shape, then only move device.
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        return latents.to(device)

    return randn_tensor(shape, generator=generator, device=device, dtype=dtype)
Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + ): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + division = self.vae_scale_factor * 2 + S_max = (self.default_sample_size * self.vae_scale_factor) ** 2 + scale = S_max / (width * height) + scale = math.sqrt(scale) + width, height = int(width * scale // division * division), int(height * scale // division * division) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + prompt_4=prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds_arr = [] + for n, p in zip(negative_prompt_embeds, prompt_embeds): + if len(n.shape) == 3: + prompt_embeds_arr.append(torch.cat([n, p], dim=0)) + else: + prompt_embeds_arr.append(torch.cat([n, p], dim=1)) + prompt_embeds = prompt_embeds_arr + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + pooled_prompt_embeds.dtype, + device, + generator, + latents, + ) + + if latents.shape[-2] != latents.shape[-1]: + B, C, H, W = latents.shape + pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size + + img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) + img_ids = torch.zeros(pH, pW, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] + img_ids = img_ids.reshape(pH * pW, -1) + img_ids_pad = torch.zeros(self.transformer.max_seq, 3) + img_ids_pad[:pH*pW, :] = img_ids + + img_sizes = img_sizes.unsqueeze(0).to(latents.device) + img_ids = img_ids_pad.unsqueeze(0).to(latents.device) + if self.do_classifier_free_guidance: + img_sizes = img_sizes.repeat(2 * B, 1) + img_ids = img_ids.repeat(2 * B, 1, 1) + else: + img_sizes = img_ids = None + + # 5. Prepare timesteps + mu = calculate_shift(self.transformer.max_seq) + scheduler_kwargs = {"mu": mu} + if isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu)) + timesteps = self.scheduler.timesteps + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if latent_model_input.shape[-2] != latent_model_input.shape[-1]: + B, C, H, W = latent_model_input.shape + patch_size = self.transformer.config.patch_size + pH, pW = H // patch_size, W // patch_size + out = torch.zeros( + (B, C, self.transformer.max_seq, patch_size * patch_size), + dtype=latent_model_input.dtype, + device=latent_model_input.device + ) + latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size) + out[:, :, 0:pH*pW] = latent_model_input + latent_model_input = out + + noise_pred = self.transformer( + hidden_states = latent_model_input, + timesteps = timestep, + encoder_hidden_states = prompt_embeds, + pooled_embeds = pooled_prompt_embeds, + img_sizes = img_sizes, + img_ids = img_ids, + return_dict = False, + )[0] + noise_pred = -noise_pred + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return HiDreamImagePipelineOutput(images=image) \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image_editing.py b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image_editing.py new file mode 100644 index 0000000000000000000000000000000000000000..9afd36ac38622e34dd1b7b150b0b9d8c0d2ca6d5 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_hidream_image_editing.py @@ -0,0 +1,1206 @@ +# ref https://github.com/HiDream-ai/HiDream-E1/blob/main/pipeline_hidream_image_editing.py +import inspect +from typing import Any, Callable, Dict, List, Optional, Union +import PIL + +import torch +from transformers import ( + 
CLIPTextModelWithProjection, + CLIPTokenizer, + LlamaForCausalLM, + PreTrainedTokenizerFast, + T5EncoderModel, + T5Tokenizer, +) + +from diffusers.image_processor import VaeImageProcessor, PipelineImageInput +from diffusers.loaders import HiDreamImageLoraLoaderMixin +from diffusers.models import AutoencoderKL, HiDreamImageTransformer2DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler +from diffusers.utils import deprecate, is_torch_xla_available, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.hidream_image.pipeline_output import HiDreamImagePipelineOutput +import logging + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig( + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() # Ensure output goes to console + ] +) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM + >>> from diffusers import UniPCMultistepScheduler + >>> from pipeline_hidream_image_editing import HiDreamImageEditingPipeline + >>> from PIL import Image + + + >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( + ... "meta-llama/Meta-Llama-3.1-8B-Instruct", + ... output_hidden_states=True, + ... output_attentions=True, + ... torch_dtype=torch.bfloat16, + ... ) + + >>> pipe = HiDreamImageEditingPipeline.from_pretrained( + ... "HiDream-ai/HiDream-E1-Full", + ... tokenizer_4=tokenizer_4, + ... text_encoder_4=text_encoder_4, + ... torch_dtype=torch.bfloat16, + ... 
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the scheduler shift ``mu`` for a sequence length.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; the value at ``image_seq_len`` is returned.
    Values outside that range are extrapolated, not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HiDreamImageEditingPipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5Tokenizer, + text_encoder_4: LlamaForCausalLM, + tokenizer_4: PreTrainedTokenizerFast, + transformer: HiDreamImageTransformer2DModel, + aggressive_unloading: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + text_encoder_4=text_encoder_4, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + tokenizer_4=tokenizer_4, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if 
hasattr(self, "vae") and self.vae is not None else 8 + ) + # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + if getattr(self, "tokenizer_4", None) is not None: + self.tokenizer_4.pad_token = self.tokenizer_4.eos_token + + + self.aggressive_unloading = aggressive_unloading + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_3.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def 
_get_clip_prompt_embeds( + self, + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=min(max_sequence_length, 218), + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {218} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def _get_llama3_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_4.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_4( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids + + if 
untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_4.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}" + ) + + outputs = self.text_encoder_4( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + output_hidden_states=True, + output_attentions=True, + ) + + prompt_embeds = outputs.hidden_states[1:] + prompt_embeds = torch.stack(prompt_embeds, dim=0) + return prompt_embeds + + def encode_prompt( + self, + prompt: Optional[Union[str, List[str]]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + prompt_4: Optional[Union[str, List[str]]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + negative_prompt_4: Optional[Union[str, List[str]]] = None, + prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None, + prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 128, + lora_scale: Optional[float] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = 
pooled_prompt_embeds.shape[0] + + device = device or self._execution_device + + if pooled_prompt_embeds is None: + pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype + ) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if len(negative_prompt) > 1 and len(negative_prompt) != batch_size: + raise ValueError(f"negative_prompt must be of length 1 or {batch_size}") + + negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype + ) + + if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1) + + if pooled_prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + if len(prompt_2) > 1 and len(prompt_2) != batch_size: + raise ValueError(f"prompt_2 must be of length 1 or {batch_size}") + + pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype + ) + + if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + + if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size: + raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}") + + negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, 
negative_prompt_2, max_sequence_length, device, dtype + ) + + if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1) + + if pooled_prompt_embeds is None: + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1 + ) + + if prompt_embeds_t5 is None: + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + if len(prompt_3) > 1 and len(prompt_3) != batch_size: + raise ValueError(f"prompt_3 must be of length 1 or {batch_size}") + + prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype) + + if prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + + if do_classifier_free_guidance and negative_prompt_embeds_t5 is None: + negative_prompt_3 = negative_prompt_3 or negative_prompt + negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + + if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size: + raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}") + + negative_prompt_embeds_t5 = self._get_t5_prompt_embeds( + negative_prompt_3, max_sequence_length, device, dtype + ) + + if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1) + + if prompt_embeds_llama3 is None: + prompt_4 = prompt_4 or prompt + prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 + + if len(prompt_4) > 1 and len(prompt_4) != batch_size: + raise ValueError(f"prompt_4 must be of length 1 or {batch_size}") + + prompt_embeds_llama3 = 
self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype) + + if prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + + if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None: + negative_prompt_4 = negative_prompt_4 or negative_prompt + negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 + + if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size: + raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}") + + negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds( + negative_prompt_4, max_sequence_length, device, dtype + ) + + if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + + # duplicate pooled_prompt_embeds for each generation per prompt + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}") + prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1) + prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate 
prompt_embeds_llama3 of batch size {bs_embed}") + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1) + prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim) + + if do_classifier_free_guidance: + # duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len = negative_pooled_prompt_embeds.shape + if bs_embed == 1 and batch_size > 1: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}") + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}") + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}") + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, 
def enable_vae_slicing(self):
    r"""
    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.

    Thin wrapper: forwards to `self.vae.enable_slicing()` and returns nothing.
    """
    self.vae.enable_slicing()

def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
    computing decoding in one step.

    Thin wrapper: forwards to `self.vae.disable_slicing()` and returns nothing.
    """
    self.vae.disable_slicing()

def enable_vae_tiling(self):
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.

    Thin wrapper: forwards to `self.vae.enable_tiling()` and returns nothing.
    """
    self.vae.enable_tiling()

def disable_vae_tiling(self):
    r"""
    Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
    computing decoding in one step.

    Thin wrapper: forwards to `self.vae.disable_tiling()` and returns nothing.
    """
    self.vae.disable_tiling()
def check_inputs(
    self,
    prompt,
    prompt_2,
    prompt_3,
    prompt_4,
    negative_prompt=None,
    negative_prompt_2=None,
    negative_prompt_3=None,
    negative_prompt_4=None,
    prompt_embeds_t5=None,
    prompt_embeds_llama3=None,
    negative_prompt_embeds_t5=None,
    negative_prompt_embeds_llama3=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the prompt / embedding arguments of a pipeline call.

    The checks run in a fixed order and the first failing one raises, so only one problem is
    reported per call:

    1. every name in `callback_on_step_end_tensor_inputs` must be listed in
       `self._callback_tensor_inputs`;
    2. a text prompt and the pre-computed embedding it would produce must not both be given
       (`prompt`/`prompt_2` vs `pooled_prompt_embeds`, `prompt_3` vs `prompt_embeds_t5`,
       `prompt_4` vs `prompt_embeds_llama3`);
    3. when `prompt` is `None`, `pooled_prompt_embeds`, `prompt_embeds_t5` and
       `prompt_embeds_llama3` must all be given;
    4. every positive prompt that is given must be a `str` or a `list`;
    5. the same "not both" rule as (2) for the negative prompts and negative embeddings;
    6. every negative prompt that is given must be a `str` or a `list`;
    7. a positive embedding and its negative counterpart, when both are given, must have equal
       `.shape`.

    Returns:
        None. The method only validates.

    Raises:
        ValueError: on the first violated rule.
    """
    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Every branch of this chain raises, so reaching the code below it means none of them fired.
    if prompt is not None and pooled_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and pooled_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_3 is not None and prompt_embeds_t5 is not None:
        raise ValueError(
            f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_4 is not None and prompt_embeds_llama3 is not None:
        raise ValueError(
            f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and pooled_prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined."
        )
    elif prompt is None and prompt_embeds_t5 is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined."
        )
    elif prompt is None and prompt_embeds_llama3 is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined."
        )

    for name, value in (
        ("prompt", prompt),
        ("prompt_2", prompt_2),
        ("prompt_3", prompt_3),
        ("prompt_4", prompt_4),
    ):
        if value is not None and not isinstance(value, (str, list)):
            raise ValueError(f"`{name}` has to be of type `str` or `list` but is {type(value)}")

    if negative_prompt is not None and negative_pooled_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:"
            f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:"
            f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:"
            f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:"
            f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two."
        )

    # Negative prompts get the same type validation as the positive ones, so a bad value is
    # reported here instead of surfacing later as an unrelated error during encoding.
    for name, value in (
        ("negative_prompt", negative_prompt),
        ("negative_prompt_2", negative_prompt_2),
        ("negative_prompt_3", negative_prompt_3),
        ("negative_prompt_4", negative_prompt_4),
    ):
        if value is not None and not isinstance(value, (str, list)):
            raise ValueError(f"`{name}` has to be of type `str` or `list` but is {type(value)}")

    if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None:
        if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape:
            raise ValueError(
                "`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`"
                f" {negative_pooled_prompt_embeds.shape}."
            )
    if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None:
        if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape:
            raise ValueError(
                "`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but"
                f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`"
                f" {negative_prompt_embeds_t5.shape}."
            )
    if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None:
        if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape:
            raise ValueError(
                "`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but"
                f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`"
                f" {negative_prompt_embeds_llama3.shape}."
            )
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """
    Return the initial latent tensor for the denoising loop.

    The expected latent shape is `(batch_size, num_channels_latents, h, w)`, where `h` and `w`
    are the pixel sizes divided by `self.vae_scale_factor` and rounded down to an even number
    (the original comment states that packing requires even latent height and width).

    Args:
        batch_size: first dimension of the latent shape.
        num_channels_latents: second dimension of the latent shape.
        height, width: image size in pixels; converted with `int(...)` before use.
        dtype, device, generator: forwarded to `randn_tensor` when new latents are sampled.
        latents: optional caller-supplied latents. Their shape must equal the expected shape;
            they are moved to `device` (their dtype is left as-is).

    Returns:
        Freshly sampled latents when `latents` is None, otherwise `latents.to(device)`.

    Raises:
        ValueError: if caller-supplied `latents` do not have the expected shape.
    """
    # Dividing by (scale * 2) and multiplying by 2 floors each latent side to an even value.
    even_unit = self.vae_scale_factor * 2
    latent_height = 2 * (int(height) // even_unit)
    latent_width = 2 * (int(width) // even_unit)
    expected_shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if latents is None:
        return randn_tensor(expected_shape, generator=generator, device=device, dtype=dtype)

    if latents.shape != expected_shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {expected_shape}")
    return latents.to(device)
def prepare_image_latents(
    self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
):
    """
    Turn the conditioning image into the latent tensor that is concatenated to the noisy latents.

    Steps:
      1. Non-tensor inputs (`PIL.Image.Image` or `list`) are converted with
         `self.image_processor.preprocess`, the same call `__call__` applies before invoking this
         method. Previously these types passed the type check but crashed on `image.to(...)`.
      2. The tensor is moved to `device` / `dtype`.
      3. If dimension 1 has size 4 the tensor is used as-is; otherwise it is encoded with
         `self.vae.encode` (latents retrieved with `sample_mode="argmax"`) and normalised as
         `(latents - shift_factor) * scaling_factor`.
      4. If `batch_size * num_images_per_prompt` is a larger multiple of the image batch, the
         latents are tiled along dim 0 (with a deprecation warning); a larger non-multiple raises.
      5. With classifier-free guidance the result is `[zeros, latents, latents]` stacked on dim 0.

    Args:
        image: `torch.Tensor`, `PIL.Image.Image` or list accepted by the image processor.
        batch_size: number of prompts.
        num_images_per_prompt: multiplier applied to `batch_size`.
        dtype, device: target dtype and device of the image tensor.
        do_classifier_free_guidance: whether to prepend the all-zero unconditional latents.
        generator: accepted for signature compatibility; not used by this method.

    Returns:
        The image latents tensor.

    Raises:
        ValueError: if `image` has an unsupported type, or if the image batch cannot be tiled to
            the requested batch size.
    """
    if not isinstance(image, torch.Tensor):
        # `list` is tested first so that the PIL lookup is only reached for other objects.
        if not (isinstance(image, list) or isinstance(image, PIL.Image.Image)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )
        image = self.image_processor.preprocess(image)

    image = image.to(device=device, dtype=dtype)

    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        image_latents = image
    else:
        image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

    if batch_size > image_latents.shape[0]:
        if batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand image_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)

    if do_classifier_free_guidance:
        uncond_image_latents = torch.zeros_like(image_latents)
        image_latents = torch.cat([uncond_image_latents, image_latents, image_latents], dim=0)

    return image_latents
@property
def guidance_scale(self):
    # Text guidance weight stored by the most recent `__call__`.
    return self._guidance_scale

@property
def image_guidance_scale(self):
    # Image guidance weight stored by the most recent `__call__`.
    return self._image_guidance_scale

@property
def do_classifier_free_guidance(self):
    # Guidance is active only for a text guidance scale strictly greater than 1.
    return self._guidance_scale > 1

@property
def attention_kwargs(self):
    return self._attention_kwargs

@property
def num_timesteps(self):
    return self._num_timesteps

@property
def interrupt(self):
    return self._interrupt

def _load_stage_keys(self, reload_keys, stage):
    """Load `reload_keys[stage]` into `self.transformer` (non-strict) and reject unexpected keys."""
    logger.info(f"loading {stage} keys")
    load_info = self.transformer.load_state_dict(reload_keys[stage], strict=False)
    logger.info(f"finished loading {stage} keys")
    # Explicit check instead of `assert`, which is removed when Python runs with `-O`.
    if len(load_info.unexpected_keys) != 0:
        raise ValueError(
            f"Unexpected keys while loading `{stage}` weights into the transformer: {load_info.unexpected_keys}"
        )

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    prompt_4: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    num_inference_steps: int = 50,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 5.0,
    image_guidance_scale: float = 2.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    negative_prompt_4: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds_t5: Optional[torch.FloatTensor] = None,
    prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 128,
    refine_strength: float = 0.0,
    reload_keys: Any = None,
    refiner: HiDreamImageTransformer2DModel = None,
    clip_cfg_norm: bool = True,
    **kwargs,
):
    r"""
    Function invoked when calling the pipeline for image editing.

    The output size is not a parameter: it is taken from the encoded input `image`
    (latent height/width multiplied by `self.vae_scale_factor`).

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `pooled_prompt_embeds`, `prompt_embeds_t5` and `prompt_embeds_llama3` instead. If a prompt
            contains the marker `Target Image Description:`, the text after the marker is encoded separately
            and used as the prompt during the refine stage.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
            will be used instead.
        prompt_3 (`str` or `List[str]`, *optional*):
            The prompt or prompts used for the T5 embeddings. If not defined, `prompt` will be used instead.
        prompt_4 (`str` or `List[str]`, *optional*):
            The prompt or prompts used for the Llama3 embeddings. If not defined, `prompt` will be used
            instead.
        image (`PipelineImageInput`):
            The image to edit. Required. It is passed through `self.image_processor.preprocess` and then
            encoded by `prepare_image_latents`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas forwarded to `retrieve_timesteps`. Not used when the scheduler is a
            `UniPCMultistepScheduler`.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            Guidance is enabled by setting `guidance_scale > 1`. In the editing stage it weights the
            text-conditioned prediction against the image-conditioned one; in the refine stage it weights the
            conditional prediction against the unconditional one.
        image_guidance_scale (`float`, *optional*, defaults to 2.0):
            Weight of the image-conditioned prediction relative to the unconditional prediction during the
            editing stage. Only used when guidance is enabled.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
            ignored if `guidance_scale` is not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The negative prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined,
            `negative_prompt` is used.
        negative_prompt_3 (`str` or `List[str]`, *optional*):
            The negative prompt or prompts used for the T5 embeddings. If not defined, `negative_prompt` is
            used.
        negative_prompt_4 (`str` or `List[str]`, *optional*):
            The negative prompt or prompts used for the Llama3 embeddings. If not defined, `negative_prompt`
            is used.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents to be used as inputs for image generation. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds_t5 (`torch.FloatTensor`, *optional*):
            Pre-generated T5 text embeddings. If not provided, they are generated from `prompt_3` / `prompt`.
        prompt_embeds_llama3 (`torch.FloatTensor`, *optional*):
            Pre-generated Llama3 text embeddings. If not provided, they are generated from `prompt_4` /
            `prompt`.
        negative_prompt_embeds_t5 (`torch.FloatTensor`, *optional*):
            Pre-generated negative T5 text embeddings.
        negative_prompt_embeds_llama3 (`torch.FloatTensor`, *optional*):
            Pre-generated negative Llama3 text embeddings.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. If not provided, pooled text embeddings will be generated
            from the `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. If not provided, they will be generated from the
            `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"latent"` returns the final latents without decoding;
            any other value is forwarded to `self.image_processor.postprocess`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] instead of a
            plain tuple.
        attention_kwargs (`dict`, *optional*):
            Stored on the pipeline; its `"scale"` entry, if present, is passed to `encode_prompt` as
            `lora_scale`.
        callback_on_step_end (`Callable`, *optional*):
            A function called at the end of each denoising step with the arguments
            `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
            `callback_kwargs` will include all tensors specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. You will only be able to
            include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 128): Maximum sequence length to use with the `prompt`.
        refine_strength (`float`, *optional*, defaults to 0.0):
            Only used together with `reload_keys`. The pipeline switches from the editing stage to the refine
            stage at step `int(num_inference_steps * (1.0 - refine_strength))`.
        reload_keys (*optional*):
            Mapping with the entries `'editing'` and `'refine'`, each a state dict loaded (non-strictly) into
            `self.transformer` at the start of the corresponding stage. When `None`, the whole run stays in the
            editing stage.
        refiner (`HiDreamImageTransformer2DModel`, *optional*):
            Transformer used instead of `self.transformer` during the refine stage.
        clip_cfg_norm (`bool`, *optional*, defaults to `True`):
            Selects the editing-stage guidance variant: when `True`, the text-guided prediction is rescaled so
            that its norm (over dim 1) does not exceed that of the fully conditioned prediction.
        **kwargs:
            `prompt_embeds` / `negative_prompt_embeds` (deprecated): a pair whose element 0 is used as the T5
            embeddings and element 1 as the Llama3 embeddings.

    Examples:

    Returns:
        [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] or `tuple`:
        [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is a list with the generated images.
    """

    prompt_embeds = kwargs.get("prompt_embeds", None)
    negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)

    if prompt_embeds is not None:
        deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead."
        deprecate("prompt_embeds", "0.35.0", deprecation_message)
        prompt_embeds_t5 = prompt_embeds[0]
        prompt_embeds_llama3 = prompt_embeds[1]

    if negative_prompt_embeds is not None:
        deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead."
        deprecate("negative_prompt_embeds", "0.35.0", deprecation_message)
        negative_prompt_embeds_t5 = negative_prompt_embeds[0]
        negative_prompt_embeds_llama3 = negative_prompt_embeds[1]

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        prompt_4,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        negative_prompt_4=negative_prompt_4,
        prompt_embeds_t5=prompt_embeds_t5,
        prompt_embeds_llama3=prompt_embeds_llama3,
        negative_prompt_embeds_t5=negative_prompt_embeds_t5,
        negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    # Fail before the text encoders run: the image is needed unconditionally below.
    if image is None:
        raise ValueError("`image` has to be provided: this pipeline edits an existing image.")

    self._guidance_scale = guidance_scale
    self._image_guidance_scale = image_guidance_scale
    self._attention_kwargs = attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    # `check_inputs` guarantees that one of these branches is taken.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    elif pooled_prompt_embeds is not None:
        batch_size = pooled_prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode prompt
    lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
    (
        prompt_embeds_t5,
        negative_prompt_embeds_t5,
        prompt_embeds_llama3,
        negative_prompt_embeds_llama3,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        prompt_4=prompt_4,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        negative_prompt_4=negative_prompt_4,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds_t5=prompt_embeds_t5,
        prompt_embeds_llama3=prompt_embeds_llama3,
        negative_prompt_embeds_t5=negative_prompt_embeds_t5,
        negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # Extract the refine-stage prompt. `None` means "no marker found"; an empty string is a
    # valid (if unusual) target prompt and is still encoded, as before.
    # A list prompt is handled per element; previously `.split` was called on the list itself.
    target_marker = "Target Image Description:"
    target_prompt = None
    if isinstance(prompt, str):
        if target_marker in prompt:
            target_prompt = prompt.split(target_marker)[1].strip()
    elif isinstance(prompt, list) and any(target_marker in p for p in prompt):
        target_prompt = [p.split(target_marker)[1].strip() if target_marker in p else p for p in prompt]

    if target_prompt is not None:
        (
            target_prompt_embeds_t5,
            target_negative_prompt_embeds_t5,
            target_prompt_embeds_llama3,
            target_negative_prompt_embeds_llama3,
            target_pooled_prompt_embeds,
            target_negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=target_prompt,
            prompt_2=None,
            prompt_3=None,
            prompt_4=None,
            negative_prompt=negative_prompt,
            negative_prompt_2=None,
            negative_prompt_3=None,
            negative_prompt_4=None,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds_t5=None,
            prompt_embeds_llama3=None,
            negative_prompt_embeds_t5=None,
            negative_prompt_embeds_llama3=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
    else:
        # Aliases are taken BEFORE the guidance concatenation below, so the target embeddings
        # are built from the un-concatenated tensors.
        target_prompt_embeds_t5 = prompt_embeds_t5
        target_negative_prompt_embeds_t5 = negative_prompt_embeds_t5
        target_prompt_embeds_llama3 = prompt_embeds_llama3
        target_negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3
        target_pooled_prompt_embeds = pooled_prompt_embeds
        target_negative_pooled_prompt_embeds = negative_pooled_prompt_embeds

    image = self.image_processor.preprocess(image)

    image_latents = self.prepare_image_latents(
        image,
        batch_size,
        num_images_per_prompt,
        pooled_prompt_embeds.dtype,
        device,
        self.do_classifier_free_guidance,
    )

    height, width = image_latents.shape[-2:]
    height = height * self.vae_scale_factor
    width = width * self.vae_scale_factor

    if self.do_classifier_free_guidance:
        # Editing stage runs three branches per sample; the order here must line up with the
        # `[zeros, image, image]` order produced by `prepare_image_latents` and with the
        # `chunk(3)` unpacking in the denoising loop.
        if clip_cfg_norm:
            prompt_embeds_t5 = torch.cat([prompt_embeds_t5, negative_prompt_embeds_t5, prompt_embeds_t5], dim=0)
            prompt_embeds_llama3 = torch.cat([prompt_embeds_llama3, negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1)
            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
        else:
            prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, negative_prompt_embeds_t5, prompt_embeds_t5], dim=0)
            prompt_embeds_llama3 = torch.cat([negative_prompt_embeds_llama3, negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # Refine stage runs two branches per sample: [negative, positive].
        target_prompt_embeds_t5 = torch.cat([target_negative_prompt_embeds_t5, target_prompt_embeds_t5], dim=0)
        target_prompt_embeds_llama3 = torch.cat([target_negative_prompt_embeds_llama3, target_prompt_embeds_llama3], dim=1)
        target_pooled_prompt_embeds = torch.cat([target_negative_pooled_prompt_embeds, target_pooled_prompt_embeds], dim=0)

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        pooled_prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    mu = calculate_shift(self.transformer.max_seq)
    scheduler_kwargs = {"mu": mu}
    if isinstance(self.scheduler, UniPCMultistepScheduler):
        self.scheduler.set_timesteps(num_inference_steps, device=device)  # , shift=math.exp(mu))
        timesteps = self.scheduler.timesteps
    else:
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 6. Denoising loop
    refine_stage = False
    if reload_keys is not None:
        self._load_stage_keys(reload_keys, "editing")
        try:
            self.transformer.enable_adapters()
        except Exception as exc:
            # Best effort, as before, but no longer silent.
            logger.debug(f"could not enable adapters on the transformer: {exc}")
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # === STAGE DETERMINATION ===
            # Check if we need to switch from editing stage to refining stage
            if reload_keys is not None and i == int(num_inference_steps * (1.0 - refine_strength)):
                # Switch from editing to refining stage
                try:
                    self.transformer.disable_adapters()
                except Exception as exc:
                    # Best effort, as before, but no longer silent.
                    logger.debug(f"could not disable adapters on the transformer: {exc}")
                self._load_stage_keys(reload_keys, "refine")
                logger.info(f"Refining start at step {i}")
                refine_stage = True

            if self.interrupt:
                continue

            # === INPUT PREPARATION ===
            if refine_stage:
                # Refining stage: Use target prompts and simpler input (no image conditioning)
                latent_model_input_with_condition = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                current_prompt_embeds_t5 = target_prompt_embeds_t5
                current_prompt_embeds_llama3 = target_prompt_embeds_llama3
                current_pooled_prompt_embeds = target_pooled_prompt_embeds
            else:
                # Editing stage: Use original prompts and include image conditioning
                latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
                latent_model_input_with_condition = torch.cat([latent_model_input, image_latents], dim=-1)
                current_prompt_embeds_t5 = prompt_embeds_t5
                current_prompt_embeds_llama3 = prompt_embeds_llama3
                current_pooled_prompt_embeds = pooled_prompt_embeds

            # === TRANSFORMER SELECTION ===
            # Choose which transformer to use for this step
            if refine_stage and refiner is not None:
                transformer_func = refiner
            else:
                transformer_func = self.transformer

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input_with_condition.shape[0])
            noise_pred = transformer_func(
                hidden_states=latent_model_input_with_condition,
                timesteps=timestep,
                encoder_hidden_states_t5=current_prompt_embeds_t5,
                encoder_hidden_states_llama3=current_prompt_embeds_llama3,
                pooled_embeds=current_pooled_prompt_embeds,
                return_dict=False,
            )[0]
            # perform guidance
            # Keep only the part of the last dimension that belongs to the noisy latents
            # (the editing stage appended the image latents along that dimension).
            noise_pred = -1.0 * noise_pred[..., :latents.shape[-1]]
            if self.do_classifier_free_guidance:
                if refine_stage:
                    uncond, full_cond = noise_pred.chunk(2)
                    noise_pred = uncond + self.guidance_scale * (full_cond - uncond)
                else:
                    if clip_cfg_norm:
                        uncond, image_cond, full_cond = noise_pred.chunk(3)
                        pred_text_ = image_cond + self.guidance_scale * (full_cond - image_cond)
                        norm_full_cond = torch.norm(full_cond, dim=1, keepdim=True)
                        norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
                        # Only ever shrink the text-guided prediction (scale is clamped to [0, 1]).
                        scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
                        pred_text = pred_text_ * scale
                        noise_pred = uncond + self.image_guidance_scale * (pred_text - uncond)
                    else:
                        uncond, image_cond, full_cond = noise_pred.chunk(3)
                        noise_pred = uncond + self.image_guidance_scale * (image_cond - uncond) + self.guidance_scale * (
                                    full_cond - image_cond)
            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Deliberately a plain loop: `locals()` must be evaluated in this function's scope.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                # NOTE(review): the three `current_*` values below are reassigned at the top of the
                # next iteration, so embeddings returned by the callback do not persist — confirm intent.
                current_prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", current_prompt_embeds_t5)
                current_prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", current_prompt_embeds_llama3)
                current_pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", current_pooled_prompt_embeds)

            # advance the progress bar
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type == "latent":
        image = latents

    else:
        # Inverse of the `(latents - shift_factor) * scaling_factor` normalisation.
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return HiDreamImagePipelineOutput(images=image)
a/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_output.py b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..b03e2faea5802480e3de9943e890d760de4cb835 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/pipelines/hidream_image/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class HiDreamImagePipelineOutput(BaseOutput): + """ + Output class for HiDreamImage pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/schedulers/flash_flow_match.py b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/schedulers/flash_flow_match.py new file mode 100644 index 0000000000000000000000000000000000000000..122f8ed74464473757a1f85feaa3f0b132d28d13 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/schedulers/flash_flow_match.py @@ -0,0 +1,428 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput, is_scipy_available, logging +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + import scipy.stats + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlashFlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlashFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. 
Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. 
+ """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. 
+ """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved. + """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) + self.num_inference_steps = num_inference_steps + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps + + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure 
we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlashFlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. 
+ """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + + sigma = self.sigmas[self.step_index] + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + denoised = sample - model_output * sigma + + if self.step_index < self.num_inference_steps - 1: + sigma_next = self.sigmas[self.step_index + 1] + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=denoised.dtype, + ) + sample = sigma_next * noise + (1.0 - sigma_next) * denoised + + self._step_index += 1 + sample = sample.to(model_output.dtype) + + if not return_dict: + return (sample,) + + return FlashFlowMatchEulerDiscreteSchedulerOutput(prev_sample=sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. 
(2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. 
al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def __len__(self): + return self.config.num_train_timesteps diff --git a/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/schedulers/fm_solvers_unipc.py b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/schedulers/fm_solvers_unipc.py new file mode 100644 index 0000000000000000000000000000000000000000..57321baa35359782b33143321cd31c8d934a7b29 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/hidream/src/schedulers/fm_solvers_unipc.py @@ -0,0 +1,800 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+ +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. 
+ solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current 
timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then 
divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + 
Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = 
self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], + b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def 
multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - 
torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + 
schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. 
+ + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/ai-toolkit/extensions_built_in/diffusion_models/ideogram4/__init__.py 
b/ai-toolkit/extensions_built_in/diffusion_models/ideogram4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e26a32d8ee92feb2cd706f791e97b99a65855289 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/ideogram4/__init__.py @@ -0,0 +1 @@ +from .ideogram4 import Ideogram4Model diff --git a/ai-toolkit/extensions_built_in/diffusion_models/ideogram4/ideogram4.py b/ai-toolkit/extensions_built_in/diffusion_models/ideogram4/ideogram4.py new file mode 100644 index 0000000000000000000000000000000000000000..e42a4a2c8eed1e879dfbdea2d1371f1144247ae8 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/ideogram4/ideogram4.py @@ -0,0 +1,635 @@ +import os +from typing import List, Optional + +import torch +import yaml +from safetensors.torch import load_file, save_file + +from toolkit.config_modules import GenerateImageConfig, ModelConfig, NetworkConfig +from toolkit.models.base_model import BaseModel +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.basic import flush +from toolkit.print import print_acc +from toolkit.advanced_prompt_embeds import AdvancedPromptEmbeds +from toolkit.ideogram_caption import digest_caption_string +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import unwrap_model +from toolkit.metadata import get_meta_for_safetensors +from toolkit.memory_management import MemoryManager +from toolkit.util.quantize import quantize, get_qtype, quantize_model +from optimum.quanto import freeze, QTensor + +import huggingface_hub +from huggingface_hub.errors import EntryNotFoundError +from transformers import AutoModel, AutoTokenizer + +from .src.transformer import Ideogram4Config, Ideogram4Transformer2DModel +from .src.vae import AutoEncoder, AutoEncoderParams, convert_diffusers_state_dict +from .src.latent_norm import get_latent_norm +from .src.pipeline import ( + Ideogram4Pipeline, + get_qwen3_vl_features, + 
pad_text_features, + patchify_latents, + predict_velocity, + unpatchify_latents, +) + + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": None, + "stochastic_sampling": False, + "time_shift_type": "exponential", + "use_beta_sigmas": False, + "use_dynamic_shifting": False, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} + +# Weight-only FP8 (e4m3) Linear weights carry a per-output-channel float32 scale +# saved alongside as ``.weight_scale``. Folding it back gives bf16 weights. +FP8_SCALE_SUFFIX = ".weight_scale" + +# The text encoder is frozen, stock Qwen3-VL-8B-Instruct. +QWEN3_VL_PATH = "Qwen/Qwen3-VL-8B-Instruct" + +HF_TOKEN = os.getenv("HF_TOKEN", None) + + +def _dequantize_fp8_state_dict( + state_dict: dict, + dtype: torch.dtype, + device: torch.device, + low_vram: bool, +) -> dict: + """Fold weight-only FP8 scales back into the weights, casting to ``dtype``. + + Linear weights stored as float8 with a sibling ``.weight_scale`` are + reconstructed as ``weight_fp8.to(float32) * scale[:, None]``. Everything else + is simply cast to ``dtype`` (non-floating tensors are left untouched). If the + checkpoint isn't quantized this is just a dtype cast. + + The fold/cast runs on ``device`` (GPU is much faster than CPU). With + ``low_vram=True`` each tensor is moved to ``device``, processed, then moved + back to CPU so the whole bf16 model never sits on the GPU at once; otherwise + the dequantized tensors are left on ``device`` ready to load. 
+ """ + work_device = torch.device(device) + + def _finish(t: torch.Tensor) -> torch.Tensor: + return t.to("cpu") if low_vram else t + + num_fp8 = sum(1 for k in state_dict if k.endswith(FP8_SCALE_SUFFIX)) + if num_fp8 > 0: + print_acc(f" dequantizing {num_fp8} fp8 weights -> {dtype} on {work_device}") + else: + print_acc(f" casting weights -> {dtype} on {work_device}") + + out = {} + for key, tensor in state_dict.items(): + if key.endswith(FP8_SCALE_SUFFIX): + continue + scale_key = key + "_scale" + if key.endswith(".weight") and scale_key in state_dict: + w = tensor.to(work_device, torch.float32) + scale = state_dict[scale_key].to(work_device, torch.float32) + out[key] = _finish((w * scale.unsqueeze(1)).to(dtype)) + elif tensor.is_floating_point(): + out[key] = _finish(tensor.to(work_device, dtype)) + else: + out[key] = tensor + return out + + +def _load_component_state_dict(base: str, subfolder: str, basename: str) -> dict: + """Load a component's weights whether local or on the hub, sharded or single.""" + index_name = f"{basename}.safetensors.index.json" + single_name = f"{basename}.safetensors" + + # Local directory layout: // + local_dir = os.path.join(base, subfolder) + if os.path.isdir(local_dir): + index_path = os.path.join(local_dir, index_name) + if os.path.exists(index_path): + return _load_sharded(local_dir, index_path, is_local=True) + return load_file(os.path.join(local_dir, single_name)) + + # Hub repo layout: / + prefix = f"{subfolder}/" if subfolder else "" + try: + index_path = huggingface_hub.hf_hub_download( + repo_id=base, filename=f"{prefix}{index_name}", token=HF_TOKEN + ) + return _load_sharded(base, index_path, is_local=False, prefix=prefix) + except EntryNotFoundError: + single_path = huggingface_hub.hf_hub_download( + repo_id=base, filename=f"{prefix}{single_name}", token=HF_TOKEN + ) + return load_file(single_path) + + +def _load_sharded(base, index_path, is_local, prefix="") -> dict: + import json + + with open(index_path) as f: + 
index = json.load(f) + shard_files = sorted(set(index["weight_map"].values())) + state_dict = {} + num_shards = len(shard_files) + for i, shard in enumerate(shard_files): + if is_local: + shard_path = os.path.join(base, shard) + else: + print_acc(f" downloading shard {i + 1}/{num_shards}: {shard}") + shard_path = huggingface_hub.hf_hub_download( + repo_id=base, filename=f"{prefix}{shard}", token=HF_TOKEN + ) + print_acc(f" loading shard {i + 1}/{num_shards}: {shard}") + state_dict.update(load_file(shard_path)) + return state_dict + + +class Ideogram4Model(BaseModel): + arch = "ideogram4" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.use_old_lokr_format = False + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["Ideogram4Transformer2DModel"] + + self.patch_size = 2 + self.vae_scale_factor = 8 + # Safety cap on caption token length (truncation only). Captions are stored + # per-sample at their natural length and padded to the batch max at the + # model call, so this is just an upper bound for very long JSON prompts. + self.max_text_length = int( + self.model_config.model_kwargs.get("max_text_length", 3072) + ) + + self._latent_shift = None + self._latent_scale = None + + # Optional LoRA that is only switched on during the unconditional (negative) + # CFG pass. Loaded from model_config.unconditional_lora_path if set; stays + # inactive everywhere else (training, conditional pass). + self.unconditional_lora: Optional[LoRASpecialNetwork] = None + + @property + def text_embedding_space_version(self): + # we changed the embeddings. invalidate cache. 
+ return self.arch + "_te_v2" + + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + # 8 for the VAE downsample, 2 for the patch size. + return self.vae_scale_factor * self.patch_size + + # ------------------------------------------------------------------ + # Loading + # ------------------------------------------------------------------ + def _load_text_encoder(self, base: str): + dtype = self.torch_dtype + # The text encoder is frozen, stock Qwen3-VL-8B-Instruct. The ideogram repo + # only ships an fp8 copy of it, so load the public bf16 model directly -- + # faster and higher precision than dequantizing the fp8 weights. + te_path = self.model_config.model_kwargs.get("text_encoder_path", QWEN3_VL_PATH) + self.print_and_status_update(f"Loading Qwen3-VL text encoder from {te_path}") + + tokenizer = AutoTokenizer.from_pretrained(te_path, token=HF_TOKEN) + text_encoder = AutoModel.from_pretrained( + te_path, torch_dtype=dtype, token=HF_TOKEN + ) + flush() + + text_encoder.eval() + text_encoder.requires_grad_(False) + return tokenizer, text_encoder + + def _load_transformer(self, base: str): + dtype = self.torch_dtype + self.print_and_status_update("Loading transformer") + + transformer_config = Ideogram4Config() + with torch.device("meta"): + transformer = Ideogram4Transformer2DModel(transformer_config) + + self.print_and_status_update(" - fetching transformer weights") + state_dict = _load_component_state_dict( + base, "transformer", "diffusion_pytorch_model" + ) + self.print_and_status_update(" - dequantizing transformer weights") + state_dict = _dequantize_fp8_state_dict( + state_dict, dtype, self.device_torch, self.model_config.low_vram + ) + self.print_and_status_update(" - loading transformer state dict") + transformer.load_state_dict(state_dict, assign=True) + del state_dict + flush() + + # inv_freq is a non-persistent buffer absent from the checkpoint; rebuild 
+ # it now that the module is off the meta device. + head_dim = transformer_config.emb_dim // transformer_config.num_heads + inv_freq = 1.0 / ( + transformer_config.rope_theta + ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) + ) + transformer.rotary_emb.register_buffer("inv_freq", inv_freq, persistent=False) + return transformer + + def _load_vae(self, base: str): + dtype = self.torch_dtype + self.print_and_status_update("Loading VAE") + vae_sd = _load_component_state_dict(base, "vae", "diffusion_pytorch_model") + vae_sd = convert_diffusers_state_dict(vae_sd) + vae = AutoEncoder(AutoEncoderParams()) + vae.load_state_dict(vae_sd) + del vae_sd + vae.to(self.vae_device_torch, dtype=dtype) + vae.eval() + vae.requires_grad_(False) + return vae + + def load_unconditional_lora(self, transformer: Ideogram4Transformer2DModel): + """Load the unconditional-pass LoRA and leave it applied but inactive. + + The adapter is wired into the transformer via ``apply_to`` (no merge) so + the pipeline can flip ``is_active`` on for the unconditional CFG pass only. + It never affects the conditional pass or training, where it stays inactive. + """ + lora_path = self.model_config.unconditional_lora_path + self.print_and_status_update(f"Loading unconditional LoRA from {lora_path}") + + if not os.path.exists(lora_path): + # assume it is a "repo/owner/filename.safetensors" hub path + lora_splits = lora_path.split("/") + if len(lora_splits) != 3: + raise ValueError( + f"Unconditional LoRA path {lora_path} is not a valid local path " + "or hub path." + ) + repo_id = "/".join(lora_splits[:2]) + filename = lora_splits[2] + try: + lora_path = huggingface_hub.hf_hub_download( + repo_id=repo_id, filename=filename, token=HF_TOKEN + ) + self.model_config.unconditional_lora_path = lora_path + except Exception as e: + raise ValueError( + f"Failed to download unconditional LoRA from {lora_path}: {e}" + ) + + # Detect the LoRA rank from the first down-projection weight in the file. 
+ lora_state_dict = load_file(lora_path) + lora_dim = None + for key, value in lora_state_dict.items(): + if key.endswith("lora_A.weight") or key.endswith("lora_down.weight"): + lora_dim = int(value.shape[0]) + break + if lora_dim is None: + raise ValueError( + f"Could not determine LoRA rank from {lora_path}: no lora_A/lora_down " + "weights found." + ) + + # transformer_only=False so every nn.Linear in the model is targeted (not + # just the transformer blocks) -- the extraction script factors all linears, + # so the adapter must wrap all of them to load every key. + network_config = NetworkConfig( + type="lora", + linear=lora_dim, + linear_alpha=lora_dim, + transformer_only=False, + ) + network = LoRASpecialNetwork( + text_encoder=None, + unet=transformer, + lora_dim=lora_dim, + multiplier=1.0, + alpha=lora_dim, + # train_unet just gates module creation here; the network is applied, + # kept inactive, and never trained (the pipeline only toggles is_active). + train_unet=True, + train_text_encoder=False, + network_config=network_config, + network_type="lora", + transformer_only=False, + is_transformer=True, + target_lin_modules=self.target_lora_modules, + # base_model_ref lets load_weights run convert_lora_weights_before_load + # so saved "diffusion_model." keys map back to "transformer.". + base_model=self, + ) + network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True) + network.force_to(self.device_torch, dtype=self.torch_dtype) + network._update_torch_multiplier() + network.load_weights(lora_path) + network.eval() + + # Inactive by default; the pipeline flips this on only for the uncond pass. 
+ network.is_active = False + self.unconditional_lora = network + self.print_and_status_update("Unconditional LoRA loaded (inactive)") + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading Ideogram4 model") + base = self.model_config.name_or_path + + transformer = self._load_transformer(base) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + else: + transformer.to(self.device_torch, dtype=dtype) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=[ + transformer.rotary_emb.inv_freq, + transformer.input_proj, + transformer.llm_cond_proj, + ], + ) + elif self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + else: + # quantize_model leaves the model on CPU; make sure it lands on device. 
+ transformer.to(self.device_torch) + flush() + + tokenizer, text_encoder = self._load_text_encoder(base) + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Text Encoder") + text_encoder.to(self.device_torch) + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + elif self.model_config.low_vram: + self.print_and_status_update("Moving text encoder to CPU") + text_encoder.to("cpu") + else: + self.print_and_status_update("Moving text encoder to device") + text_encoder.to(self.device_torch) + flush() + + vae = self._load_vae(base) + + self.noise_scheduler = Ideogram4Model.get_train_scheduler() + + shift, scale = get_latent_norm() + self._latent_shift = shift.view(1, -1, 1, 1) + self._latent_scale = scale.view(1, -1, 1, 1) + + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.model = transformer + self.pipeline = Ideogram4Pipeline(self) + + if self.model_config.unconditional_lora_path is not None: + self.load_unconditional_lora(transformer) + + self.print_and_status_update("Model Loaded") + + # ------------------------------------------------------------------ + # Generation + # ------------------------------------------------------------------ + def get_generation_pipeline(self): + return Ideogram4Pipeline(self) + + def generate_single_image( + self, + pipeline: Ideogram4Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: AdvancedPromptEmbeds, + unconditional_embeds: AdvancedPromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + sc = self.get_bucket_divisibility() + gen_config.width = 
int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + img = pipeline( + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + )[0] + return img + + # ------------------------------------------------------------------ + # Training hooks + # ------------------------------------------------------------------ + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, # (B, 128, gh, gw) + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: AdvancedPromptEmbeds, + **kwargs, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + t01 = timestep.to(self.device_torch, dtype=torch.float32) / 1000.0 + if t01.dim() == 0: + t01 = t01.unsqueeze(0) + if t01.shape[0] != latent_model_input.shape[0]: + t01 = t01.expand(latent_model_input.shape[0]) + + # Pad the per-sample caption features to the batch max here. + llm_features, text_mask = pad_text_features( + text_embeddings.text_embeds, self.device_torch, self.torch_dtype + ) + + pred = predict_velocity( + self.transformer, + latent_model_input.to(self.device_torch), + t01, + llm_features, + text_mask, + ) + return pred + + def get_prompt_embeds(self, prompt) -> AdvancedPromptEmbeds: + if isinstance(prompt, str): + prompt = [prompt] + + if self.text_encoder.device == torch.device("cpu"): + self.text_encoder.to(self.device_torch) + device = self.text_encoder.device + + # Encode each caption at its natural length (no cross-sample padding) and + # store one feature tensor per batch item. Padding to a common length is + # deferred to the model call, so caching a prompt only stores its real + # length -- important for the long structured (JSON) captions. 
+ features_list = [] + for p in prompt: + # Digest the prompt: migrate any old-format Ideogram caption into the + # current schema and serialize it compact (the form the renderer wants). + # Plain-text prompts pass straight through unchanged. + p = digest_caption_string(p) + messages = [{"role": "user", "content": [{"type": "text", "text": p}]}] + text = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + ids = self.tokenizer( + text, + add_special_tokens=False, + truncation=True, + max_length=self.max_text_length, + )["input_ids"] + if len(ids) == 0: + ids = [self.tokenizer.eos_token_id or 0] + + token_ids = torch.tensor([ids], dtype=torch.long, device=device) + attention_mask = torch.ones_like(token_ids) + pos_2d = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0).to(torch.long) + + features = get_qwen3_vl_features( + self.text_encoder, token_ids, attention_mask, pos_2d + ) # (1, Lt, D) + features_list.append(features[0].to(self.torch_dtype)) + + return AdvancedPromptEmbeds(text_embeds=features_list) + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + # ------------------------------------------------------------------ + # VAE + # ------------------------------------------------------------------ + def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + if self.vae.device == torch.device("cpu"): + self.vae.to(self.vae_device_torch) + + if isinstance(image_list, list): + images = torch.stack(image_list, dim=0) + else: + images = image_list + images = images.to(device, dtype=dtype) + + ae_channels = self.vae.params.z_channels + moments = self.vae.encoder(images) + mean = moments[:, :ae_channels] + + patched = patchify_latents(mean, self.patch_size) + shift = self._latent_shift.to(patched.device, patched.dtype) + scale = 
def decode_latents(self, latents: torch.Tensor, device=None, dtype=None):
    """Decode normalized, patchified latents back to images with the VAE.

    ``device`` / ``dtype`` default to the VAE's configured device and dtype.
    The stored per-channel normalization is undone (``latents * scale +
    shift``), the patches are unfolded, and the VAE decoder output is
    returned as-is.
    """
    target_device = self.vae_device_torch if device is None else device
    target_dtype = self.vae_torch_dtype if dtype is None else dtype
    # Lazily bring the VAE onto its device if it is parked on CPU.
    if self.vae.device == torch.device("cpu"):
        self.vae.to(self.vae_device_torch)

    normalized = latents.to(target_device, dtype=target_dtype)
    shift = self._latent_shift.to(target_device, target_dtype)
    scale = self._latent_scale.to(target_device, target_dtype)
    denormalized = normalized * scale + shift
    return self.vae.decoder(unpatchify_latents(denormalized, self.patch_size))


# ------------------------------------------------------------------
# Saving / misc
# ------------------------------------------------------------------
def get_loss_target(self, *args, **kwargs):
    """Return the detached training target ``noise - batch.latents``.

    Reads the ``noise`` and ``batch`` keyword arguments; positional
    arguments are ignored.
    """
    noise, batch = kwargs.get("noise"), kwargs.get("batch")
    target = noise - batch.latents
    return target.detach()


def save_model(self, output_path, meta, save_dtype):
    """Write the transformer weights to a ``.safetensors`` file.

    The extension is appended when missing. Quantized tensors are
    dequantized, and every tensor is cloned to CPU in ``save_dtype``.
    """
    extension = ".safetensors"
    if not output_path.endswith(extension):
        output_path = output_path + extension
    transformer: Ideogram4Transformer2DModel = unwrap_model(self.model)

    def _export(tensor):
        # Quantized weights are stored dequantized.
        if isinstance(tensor, QTensor):
            tensor = tensor.dequantize()
        return tensor.clone().to("cpu", dtype=save_dtype)

    save_dict = {
        name: _export(tensor) for name, tensor in transformer.state_dict().items()
    }
    meta = get_meta_for_safetensors(meta, name="ideogram4")
    save_file(save_dict, output_path, metadata=meta)


def get_base_model_version(self):
    """Identifier string for this base model."""
    return "ideogram4"


def get_transformer_block_names(self) -> Optional[List[str]]:
    """Attribute names on the transformer that hold its block list."""
    return ["layers"]


def convert_lora_weights_before_save(self, state_dict):
    """Rename LoRA keys for saving: ``transformer.`` -> ``diffusion_model.``."""
    return {
        key.replace("transformer.", "diffusion_model."): value
        for key, value in state_dict.items()
    }


def convert_lora_weights_before_load(self, state_dict):
    """Rename LoRA keys for loading: ``diffusion_model.`` -> ``transformer.``."""
    return {
        key.replace("diffusion_model.", "transformer."): value
        for key, value in state_dict.items()
    }
+ 0.34254215, + -0.03153528, + -0.0310082, + -0.10761415, + -0.14730405, + -0.02475182, + -0.2285588, + 0.2515081, + -0.10445128, + 0.12446, + 0.07062869, + 0.30880162, + -0.18016875, + -0.18869164, + -0.34533499, + -0.0129177, + 0.02578168, + 0.07993659, + 0.28642181, + 0.26038408, + -0.22459419, + -0.14820155, + 0.04059549, + -0.14043529, + -0.16111187, + -0.2020305, + 0.02602069, + 0.04852717, + 0.10432153, + -0.06309942, + 0.38402443, + -0.22397003, + 0.34814481, + -0.03774432, + -0.03381438, + -0.11245691, + -0.14128767, + -0.02853208, + -0.21752016, + 0.24872463, + -0.11399775, + 0.1222687, + 0.05620835, + 0.309178, + -0.18065738, + -0.19401479, + -0.34495114, + -0.01760592, +) + +LATENT_SCALE: tuple[float, ...] = ( + 1.63933691, + 1.70204478, + 1.73642566, + 1.90004803, + 1.6675316, + 1.69059584, + 1.56853198, + 1.62314944, + 1.89106626, + 1.58086668, + 1.60822129, + 1.60962993, + 1.63322129, + 1.56074359, + 1.73419528, + 1.7919265, + 1.64040632, + 1.66802808, + 1.60390303, + 1.75480492, + 1.63187587, + 1.64334594, + 1.61722884, + 1.60146046, + 1.63459219, + 1.55291476, + 1.68771497, + 1.68415657, + 1.78966054, + 1.66631641, + 1.65626686, + 1.65976433, + 1.63487607, + 1.69513249, + 1.72933756, + 1.91310663, + 1.67035057, + 1.72286863, + 1.56719251, + 1.61934825, + 1.88628859, + 1.56911539, + 1.59455129, + 1.60829869, + 1.62470611, + 1.56052853, + 1.73677003, + 1.77563606, + 1.63732541, + 1.66370527, + 1.59508952, + 1.75153949, + 1.63029275, + 1.64517667, + 1.61659342, + 1.59722044, + 1.64103121, + 1.5408531, + 1.68610394, + 1.67772755, + 1.78998563, + 1.66621713, + 1.65458955, + 1.66041308, + 1.64710857, + 1.68163503, + 1.74000294, + 1.92784786, + 1.67411194, + 1.67395548, + 1.57406532, + 1.62199356, + 1.87618195, + 1.5584375, + 1.57438785, + 1.61711053, + 1.63094305, + 1.55644029, + 1.73124302, + 1.80666627, + 1.6463621, + 1.65932006, + 1.60816188, + 1.75682671, + 1.64695873, + 1.63121722, + 1.61380832, + 1.60478651, + 1.63396035, + 1.53505068, + 
def get_latent_norm() -> tuple[torch.Tensor, torch.Tensor]:
    """Return the per-channel latent ``(shift, scale)`` as float32 tensors.

    Both tensors are built from the module-level ``LATENT_SHIFT`` and
    ``LATENT_SCALE`` tables and are asserted to have shape ``(128,)``.
    """
    shift, scale = (
        torch.tensor(table, dtype=torch.float32)
        for table in (LATENT_SHIFT, LATENT_SCALE)
    )
    assert shift.shape == (128,) and scale.shape == (128,)
    return shift, scale
# log-SNR limits that bound the schedule's time values.
_LOGSNR_MIN = -15.0
_LOGSNR_MAX = 18.0


def _logit_normal_schedule(
    u: torch.Tensor,
    mean: float,
    std: float,
) -> torch.Tensor:
    """Reference Ideogram time schedule, where 0 is noise and 1 is clean.

    Maps quantiles ``u`` in [0, 1] through the normal inverse CDF, shifts and
    scales by ``mean`` / ``std``, squashes with a sigmoid, and clamps the
    result to the range implied by ``_LOGSNR_MIN`` / ``_LOGSNR_MAX``.
    Computed in float64.
    """
    quantiles = torch.as_tensor(u, dtype=torch.float64)
    logits = mean + std * torch.special.ndtri(quantiles)
    clean_time = 1.0 - torch.special.expit(logits)
    lower = 1.0 / (1.0 + math.exp(0.5 * _LOGSNR_MAX))
    upper = 1.0 / (1.0 + math.exp(0.5 * _LOGSNR_MIN))
    return clean_time.clamp(lower, upper)


def get_ideogram4_sigmas(
    num_steps: int,
    width: int,
    height: int,
    mu: float = 0.0,
    std: float = 1.75,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build the resolution-aware sigma schedule used by ComfyUI/Ideogram.

    Args:
        num_steps: Number of sampling steps (>= 1).
        width, height: Image size in pixels; the schedule mean is shifted by
            half the log of the area relative to 512x512.
        mu: Base mean of the logit-normal schedule.
        std: Standard deviation of the schedule (> 0).
        device: Device for the returned tensor.

    Returns:
        float32 tensor of ``num_steps + 1`` sigmas, descending, ending at 0.

    Raises:
        ValueError: On a non-positive step count, size, or ``std``.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if width < 1 or height < 1:
        raise ValueError("width and height must be positive")
    if std <= 0:
        raise ValueError("std must be positive")

    reference_area = 512 * 512
    shifted_mean = mu + 0.5 * math.log((width * height) / reference_area)
    quantiles = torch.linspace(0.0, 1.0, num_steps + 1, dtype=torch.float64)
    noise_levels = 1.0 - _logit_normal_schedule(quantiles, shifted_mean, std)
    # Reverse so sampling runs from high noise to low, and pin the final
    # value to exactly zero.
    descending = torch.flip(noise_levels, dims=(0,))
    descending[-1] = 0.0
    return descending.to(device=device, dtype=torch.float32)
def patchify_latents(z: torch.Tensor, patch_size: int = 2) -> torch.Tensor:
    """(B, ae_ch, H8, W8) -> (B, ae_ch * patch**2, gh, gw).

    Folds every ``patch_size x patch_size`` spatial tile into the channel
    axis. Output channels are ordered (patch_h, patch_w, ae_ch) with ae_ch
    varying fastest.
    """
    batch, channels, height, width = z.shape
    rows, cols = height // patch_size, width // patch_size
    # Split each spatial axis into (grid index, offset inside the patch).
    tiles = z.view(batch, channels, rows, patch_size, cols, patch_size)
    # Reorder to (B, ph, pw, ae_ch, gh, gw) so the three leading non-batch
    # axes can be merged into the channel axis.
    tiles = tiles.permute(0, 3, 5, 1, 2, 4)
    return tiles.reshape(batch, patch_size * patch_size * channels, rows, cols)


def unpatchify_latents(z: torch.Tensor, patch_size: int = 2) -> torch.Tensor:
    """(B, ae_ch * patch**2, gh, gw) -> (B, ae_ch, H8, W8).

    Inverse of ``patchify_latents`` for the same ``patch_size``.
    """
    batch, packed_channels, rows, cols = z.shape
    channels = packed_channels // (patch_size * patch_size)
    # Split the packed channel axis back into (ph, pw, ae_ch).
    tiles = z.view(batch, patch_size, patch_size, channels, rows, cols)
    # Reorder to (B, ae_ch, gh, ph, gw, pw), then merge each grid axis with
    # its in-patch offset to recover the full spatial resolution.
    tiles = tiles.permute(0, 3, 4, 1, 5, 2)
    return tiles.reshape(batch, channels, rows * patch_size, cols * patch_size)
def pad_text_features(
    features_list: List[torch.Tensor],
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad per-sample ``(Lt_i, D)`` caption features into one batch.

    Each caption is kept at its own length until this point; here the list is
    packed into a zero-padded ``(B, Lt_max, D)`` tensor on ``device`` in
    ``dtype``. Also returns a ``(B, Lt_max)`` long mask that is 1 on real
    tokens and 0 on padding.
    """
    lengths = [sample.shape[0] for sample in features_list]
    padded = torch.zeros(
        len(features_list),
        max(lengths),
        features_list[0].shape[-1],
        device=device,
        dtype=dtype,
    )
    mask = torch.zeros(padded.shape[:2], dtype=torch.long, device=device)
    for row, (sample, length) in enumerate(zip(features_list, lengths)):
        padded[row, :length] = sample.to(device, dtype)
        mask[row, :length] = 1
    return padded, mask
+ text_mask_bool = text_mask.to(device) > 0 + text_mask_long = text_mask_bool.long() + + # noise tokens: text region is zeroed (masked out anyway) + x = torch.cat( + [ + torch.zeros(b, num_text_tokens, c, device=device, dtype=image_tokens.dtype), + image_tokens, + ], + dim=1, + ) + + # llm features: image region is zero + llm_full = torch.cat( + [ + llm_features, + torch.zeros( + b, + num_image_tokens, + llm_features.shape[-1], + device=device, + dtype=llm_features.dtype, + ), + ], + dim=1, + ) + + # indicator: real text -> 3, image -> 2, text pad -> 0 + indicator = torch.zeros(b, seq_len, dtype=torch.long, device=device) + indicator[:, :num_text_tokens] = text_mask_long * LLM_TOKEN_INDICATOR + indicator[:, num_text_tokens:] = OUTPUT_IMAGE_INDICATOR + + # segment ids: real text + image -> 1, text pad -> -1 (its own padding segment) + segment_ids = torch.ones(b, seq_len, dtype=torch.long, device=device) + segment_ids[:, :num_text_tokens] = torch.where( + text_mask_bool, + torch.ones_like(text_mask_long), + torch.full_like(text_mask_long, SEQUENCE_PADDING_INDICATOR), + ) + + # position ids (t, h, w) + # text positions: 0..num_real-1 at the real slots (relative; pad -> 0) + text_pos = (text_mask_long.cumsum(dim=-1) - 1).clamp(min=0) # (B, Lt) + text_pos_3d = text_pos.unsqueeze(-1).expand(-1, -1, 3) + + h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1) + w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1) + t_idx = torch.zeros_like(h_idx) + image_pos = torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET + image_pos_3d = image_pos.unsqueeze(0).expand(b, -1, -1) + + position_ids = torch.cat([text_pos_3d, image_pos_3d], dim=1) + + # Flip into the model's time convention (t=1 -> clean). 
class Ideogram4Pipeline:
    """Lightweight flow-matching sampler used by ai-toolkit's preview generation.

    Wraps an ``Ideogram4Model`` and runs an Euler integration of the predicted
    velocity over the Ideogram sigma schedule, optionally with classifier-free
    guidance, then decodes the latents to PIL images.
    """

    def __init__(self, model):
        # ``model`` is the Ideogram4Model so we can reuse its encode/decode and
        # latent helpers without duplicating state.
        self.model = model

    @property
    def device(self):
        """Device of the wrapped model (``model.device_torch``)."""
        return self.model.device_torch

    def to(self, *args, **kwargs):
        """No-op kept for pipeline-API compatibility; returns ``self``."""
        return self

    @torch.no_grad()
    def __call__(
        self,
        conditional_embeds,
        unconditional_embeds,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.0,
        latents: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> List[Image.Image]:
        """Sample images for the given prompt embeddings.

        Args:
            conditional_embeds: Prompt embeds; ``text_embeds`` is the list of
                per-sample feature tensors.
            unconditional_embeds: Ignored (see the asymmetric-CFG note below).
            height, width: Output size in pixels.
            num_inference_steps: Number of Euler steps.
            guidance_scale: CFG scale; guidance is applied only when > 1.0.
            latents: Optional starting noise in the ``(B, C, gh, gw)`` layout;
                drawn with ``generator`` for a single sample when omitted.
            generator: RNG used when ``latents`` is not supplied.
            **kwargs: Accepted and ignored.

        Returns:
            One PIL image per latent in the batch.
        """
        model = self.model
        device = model.device_torch
        dtype = model.torch_dtype
        transformer = model.transformer
        patch = model.patch_size

        # Schedule parameters can be overridden through the model config.
        schedule_mu = float(
            model.model_config.model_kwargs.get("ideogram_schedule_mu", 0.0)
        )
        schedule_std = float(
            model.model_config.model_kwargs.get("ideogram_schedule_std", 1.75)
        )
        sigmas = get_ideogram4_sigmas(
            num_inference_steps,
            width,
            height,
            mu=schedule_mu,
            std=schedule_std,
            device=device,
        )

        # Token grid size: pixels -> VAE latents -> patches.
        ae_scale = model.vae_scale_factor  # 8
        gh = height // (ae_scale * patch)
        gw = width // (ae_scale * patch)
        latent_channels = transformer.config.in_channels

        # Ideogram uses asymmetric CFG: the unconditional branch is image-only
        # (no text tokens) with zeroed text features -- it does NOT run a negative
        # prompt through the text encoder. So we ignore unconditional_embeds and
        # build an empty (0-length) text sequence for the uncond pass below.
        do_cfg = guidance_scale > 1.0

        if latents is None:
            shape = (1, latent_channels, gh, gw)
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=torch.float32
            )
        # The sampler state is kept in float32; it is cast to the model dtype
        # only for each forward pass.
        latents = latents.to(device, dtype=torch.float32)
        # Scale the starting noise by the first (largest) sigma.
        latents = latents * sigmas[0]

        cond_feats, cond_mask = pad_text_features(
            conditional_embeds.text_embeds, device, dtype
        )
        if do_cfg:
            # Image-only unconditional: zero-length text sequence. predict_velocity
            # then produces an image-token-only forward pass with zeroed llm
            # features, matching the reference's asymmetric CFG.
            batch_size = latents.shape[0]
            text_dim = cond_feats.shape[-1]
            uncond_feats = torch.zeros(
                batch_size, 0, text_dim, device=device, dtype=dtype
            )
            uncond_mask = torch.zeros(batch_size, 0, dtype=torch.long, device=device)

        # The unconditional LoRA (if present) must be active *only* on the
        # unconditional pass. We force it off before each conditional pass since the
        # outer sampling context (``with network:``) may switch it on globally.
        uncond_lora = getattr(model, "unconditional_lora", None)

        # Euler integration over consecutive (sigma, sigma_next) pairs.
        for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
            # One time value per sample in the batch.
            t01 = sigma.expand(latents.shape[0])
            if uncond_lora is not None:
                uncond_lora.is_active = False
            v_cond = predict_velocity(
                transformer, latents.to(dtype), t01, cond_feats, cond_mask
            )
            if do_cfg:
                if uncond_lora is not None:
                    uncond_lora.is_active = True
                try:
                    v_uncond = predict_velocity(
                        transformer, latents.to(dtype), t01, uncond_feats, uncond_mask
                    )
                finally:
                    # Always switch the unconditional LoRA back off, even if
                    # the forward pass raised.
                    if uncond_lora is not None:
                        uncond_lora.is_active = False
                v = v_uncond + guidance_scale * (v_cond - v_uncond)
            else:
                v = v_cond
            # sigma decreases, so (sigma_next - sigma) is negative: step
            # against the predicted velocity.
            latents = latents + v.to(torch.float32) * (sigma_next - sigma)

        images = model.decode_latents(latents, device=device, dtype=dtype)
        # Map the decoder output from [-1, 1] to uint8 [0, 255], channels-last.
        images = images.float().clamp(-1.0, 1.0)
        images = ((images + 1.0) * 127.5).round().to(torch.uint8)
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        return [Image.fromarray(arr) for arr in images]
# flash-attn is optional: remember whether it imported so the "flash" backend
# can be rejected cleanly when the package is missing.
try:
    from flash_attn import flash_attn_varlen_func

    _FLASH_ATTN_AVAILABLE = True
except ImportError:
    flash_attn_varlen_func = None
    _FLASH_ATTN_AVAILABLE = False

# Supported attention backends. "native" -> SDPA, "flash" -> Flash Attention 2.
ATTENTION_BACKENDS = ("native", "flash")

# Per-token role indicators used inside the packed sequence.
SEQUENCE_PADDING_INDICATOR = -1
OUTPUT_IMAGE_INDICATOR = 2
LLM_TOKEN_INDICATOR = 3

# Image grid coordinates start at this offset so they never collide with text
# token indices (text positions start at 0 and never exceed max_text_tokens).
IMAGE_POSITION_OFFSET = 65536

# Layers of Qwen3-VL whose hidden states are concatenated and fed to the transformer.
QWEN3_VL_ACTIVATION_LAYERS = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 35)


@dataclass
class Ideogram4Config:
    """Hyper-parameters of the Ideogram 4 transformer backbone."""

    # Transformer width, depth and attention head count.
    emb_dim: int = 4608
    num_layers: int = 34
    num_heads: int = 18
    # Hidden width of the feed-forward block.
    intermediate_size: int = 12288
    # Width of the conditioning vector fed to the adaLN modulation layers.
    adanln_dim: int = 512

    # Latent dimension after patchification: ae_channels (32) * patch_size**2 (4) = 128.
    in_channels: int = 128

    # Hidden size of Qwen3-VL-8B-Instruct multiplied by the number of extracted layers.
    # Qwen3-VL hidden size = 4096
    llm_features_dim: int = 4096 * len(QWEN3_VL_ACTIVATION_LAYERS)

    # Rotary embedding base and the per-axis (t, h, w) section sizes for MRoPE.
    rope_theta: int = 5_000_000
    mrope_section: tuple[int, ...] = (24, 20, 20)

    # Epsilon passed to the transformer blocks' RMS norms.
    norm_eps: float = 1e-5


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last axis and negate the new first half.

    For ``x = [x1, x2]`` along the last axis this returns ``[-x2, x1]``.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
class Ideogram4MRoPE(nn.Module):
    """Multi-axis rotary embedding over (t, h, w) position ids.

    Frequency slots are interleaved across the three axes: slots congruent to
    1 mod 3 (up to ``3 * mrope_section[1]``) take the h position, slots
    congruent to 2 mod 3 (up to ``3 * mrope_section[2]``) take the w position,
    and every remaining slot takes the t position.
    """

    inv_freq: torch.Tensor

    def __init__(
        self,
        head_dim: int,
        base: int,
        mrope_section: tuple[int, ...],
    ) -> None:
        super().__init__()
        exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
        # Non-persistent: recomputed on construction, never stored in checkpoints.
        self.register_buffer("inv_freq", 1.0 / (base**exponents), persistent=False)
        self.mrope_section = tuple(mrope_section)
        self.head_dim = head_dim

    @torch.no_grad()
    def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables of shape (B, L, head_dim).

        ``position_ids`` is an integer tensor of shape (B, L, 3).
        """
        assert position_ids.ndim == 3 and position_ids.shape[-1] == 3
        batch_size = position_ids.shape[0]

        if self.inv_freq.device == torch.device("cpu"):
            # sometimes it gets stuck on CPU
            self.inv_freq = self.inv_freq.to(position_ids.device)

        # One row of positions per axis: (3, B, L).
        axis_positions = position_ids.permute(2, 0, 1).to(dtype=torch.float32)
        base_freqs = self.inv_freq.to(dtype=torch.float32)
        base_freqs = base_freqs[None, None, :, None].expand(3, batch_size, -1, 1)
        # Outer product of frequencies and positions -> (3, B, L, inv_freq_size).
        angles = (base_freqs @ axis_positions.unsqueeze(2)).transpose(2, 3)

        # Start from the t-axis angles and overwrite the h / w slots.
        mixed = angles[0].clone()
        for axis in (1, 2):
            stop = self.mrope_section[axis] * 3
            slots = torch.arange(axis, stop, 3, device=mixed.device)
            mixed[..., slots] = angles[axis][..., slots]

        # Duplicate so both halves of the head dimension share the angles.
        emb = torch.cat((mixed, mixed), dim=-1)
        return emb.cos(), emb.sin()


class Ideogram4RMSNorm(nn.Module):
    """RMS normalization over the last axis with a learned per-feature gain."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.rms_norm(
            x, normalized_shape=self.weight.shape, weight=self.weight, eps=self.eps
        )
class Ideogram4Attention(nn.Module):
    """Self-attention with per-head q/k RMS norm, rotary embedding, and a
    selectable backend ("native" SDPA or "flash" varlen attention)."""

    def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-5) -> None:
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attention_backend = "native"

        # Fused q/k/v projection, per-head norms, then the output projection.
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.norm_q = Ideogram4RMSNorm(self.head_dim, eps=eps)
        self.norm_k = Ideogram4RMSNorm(self.head_dim, eps=eps)
        self.o = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        flash_meta: tuple[torch.Tensor, int, torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Attend over ``x`` (B, L, hidden) and return a tensor of the same shape.

        ``attn_mask`` is used by the native backend only; ``flash_meta``
        (``cu_seqlens, max_seqlen, order, inv_order``) is used by the flash
        backend only.
        """
        batch_size, seq_len, _ = x.shape
        split_shape = (batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = self.qkv(x).view(split_shape).unbind(dim=2)

        q, k = self.norm_q(q), self.norm_k(k)

        # SDPA / rope expect (B, num_heads, L, head_dim).
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        q, k = _apply_rotary_pos_emb(q, k, cos, sin)

        if self.attention_backend != "flash":
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            out = out.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
            return self.o(out)

        # Flash Attention 2 works on packed (total_tokens, num_heads, head_dim)
        # tensors and describes the block-diagonal structure with cu_seqlens
        # over contiguous ranges. Attention groups are not contiguous in the
        # packed layout, so tokens are gathered into group order, attended,
        # and then scattered back to their original positions.
        cu_seqlens, max_seqlen, order, inv_order = flash_meta
        packed_shape = (-1, self.num_heads, self.head_dim)
        packed_q = q.transpose(1, 2).reshape(packed_shape).index_select(0, order)
        packed_k = k.transpose(1, 2).reshape(packed_shape).index_select(0, order)
        packed_v = v.transpose(1, 2).reshape(packed_shape).index_select(0, order)
        out = flash_attn_varlen_func(
            packed_q,
            packed_k,
            packed_v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=False,
        )
        out = out.index_select(0, inv_order)
        out = out.reshape(batch_size, seq_len, self.hidden_size)
        return self.o(out)


class Ideogram4MLP(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(gated)
Ideogram4MLP(hidden_size, intermediate_size) + + self.attention_norm1 = Ideogram4RMSNorm(hidden_size, eps=norm_eps) + self.ffn_norm1 = Ideogram4RMSNorm(hidden_size, eps=norm_eps) + self.attention_norm2 = Ideogram4RMSNorm(hidden_size, eps=norm_eps) + self.ffn_norm2 = Ideogram4RMSNorm(hidden_size, eps=norm_eps) + + self.adaln_modulation = nn.Linear(adanln_dim, 4 * hidden_size, bias=True) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + adaln_input: torch.Tensor, + flash_meta: tuple[torch.Tensor, int, torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + mod = self.adaln_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=-1) + gate_msa = torch.tanh(gate_msa) + gate_mlp = torch.tanh(gate_mlp) + scale_msa = 1.0 + scale_msa + scale_mlp = 1.0 + scale_mlp + + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, + attn_mask=attn_mask, + cos=cos, + sin=sin, + flash_meta=flash_meta, + ) + x = x + gate_msa * self.attention_norm2(attn_out) + x = x + gate_mlp * self.ffn_norm2( + self.feed_forward(self.ffn_norm1(x) * scale_mlp) + ) + return x + + +def _sinusoidal_embedding( + t: torch.Tensor, dim: int, scale: float = 1e4 +) -> torch.Tensor: + t = t.to(torch.float32) + half = dim // 2 + freq = math.log(scale) / (half - 1) + freq = torch.exp(torch.arange(half, dtype=torch.float32, device=t.device) * -freq) + emb = t.unsqueeze(-1) * freq + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + +class Ideogram4EmbedScalar(nn.Module): + def __init__(self, dim: int, input_range: tuple[float, float]) -> None: + super().__init__() + self.dim = dim + self.range_min, self.range_max = input_range + assert self.range_max > self.range_min + self.mlp_in = nn.Linear(dim, dim, bias=True) + self.mlp_out = nn.Linear(dim, dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x is shape 
class Ideogram4FinalLayer(nn.Module):
    """Output head: parameter-free LayerNorm, adaLN scale, then a linear map.

    The conditioning ``c`` only produces a multiplicative scale
    (``1 + adaln_modulation(silu(c))``); there is no shift or gate term.
    """

    def __init__(self, hidden_size: int, out_channels: int, adanln_dim: int) -> None:
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaln_modulation = nn.Linear(adanln_dim, hidden_size, bias=True)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``out_channels`` after scaling by the conditioning ``c``."""
        modulation = self.adaln_modulation(F.silu(c))
        scale = 1.0 + modulation
        normed = self.norm_final(x)
        return self.linear(normed * scale)
range(config.num_layers) + ] + ) + + self.final_layer = Ideogram4FinalLayer( + hidden_size=config.emb_dim, + out_channels=config.in_channels, + adanln_dim=config.adanln_dim, + ) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def set_attention_backend(self, backend: str) -> None: + """Select the attention implementation. + + Args: + backend: "native" for ``F.scaled_dot_product_attention`` or "flash" + for Flash Attention 2 (``flash_attn_varlen_func``). Selecting "flash" + requires the ``flash_attn`` package to be installed. + """ + backend = backend.lower() + if backend not in ATTENTION_BACKENDS: + raise ValueError( + f"Unknown attention backend {backend!r}. " + f"Expected one of {ATTENTION_BACKENDS}." + ) + if backend == "flash" and not _FLASH_ATTN_AVAILABLE: + raise RuntimeError( + "Flash attention 2 backend requested but the `flash_attn` package " + "is not installed. Install it with `pip install flash-attn` or use " + "the 'native' backend." + ) + self.attention_backend = backend + for layer in self.layers: + layer.attention.attention_backend = backend + + def forward( + self, + *, + llm_features: torch.Tensor, + x: torch.Tensor, + t: torch.Tensor, + position_ids: torch.Tensor, + segment_ids: torch.Tensor, + indicator: torch.Tensor, + ) -> torch.Tensor: + """Velocity prediction. + + Args: + llm_features: (B, L, llm_features_dim) Qwen3-VL conditioning features. + x: (B, L, in_channels) noise tokens. + t: (B,) or (B, L) flow-matching time in [0, 1]. + position_ids: (B, L, 3) (t, h, w) positions for MRoPE. + segment_ids: (B, L) sample id within a packed batch. + indicator: (B, L) per-token role: LLM_TOKEN_INDICATOR or OUTPUT_IMAGE_INDICATOR. 
+ + Returns: + (B, L, in_channels) velocity prediction in float32. Only the positions + with ``indicator == OUTPUT_IMAGE_INDICATOR`` are meaningful. + """ + batch_size, seq_len, in_channels = x.shape + assert in_channels == self.config.in_channels + + param_dtype = self.input_proj.weight.dtype + x = x.to(param_dtype) + t = t.to(param_dtype) + llm_features = llm_features.to(param_dtype) + + indicator = indicator.to(torch.long) + llm_token_mask = (indicator == LLM_TOKEN_INDICATOR).to(x.dtype).unsqueeze(-1) + output_image_mask = ( + (indicator == OUTPUT_IMAGE_INDICATOR).to(x.dtype).unsqueeze(-1) + ) + + llm_features = llm_features * llm_token_mask + x = x * output_image_mask + + x = self.input_proj(x) * output_image_mask + + # Keep shape (B, 1, ...) when t is per-sample so downstream adaln_modulation + # projections don't pay for L identical copies. + t_cond = self.t_embedding(t) + if t.dim() == 1: + t_cond = t_cond.unsqueeze(1) + adaln_input = F.silu(self.adaln_proj(t_cond)) + + llm_features = self.llm_cond_norm(llm_features) + llm_features = self.llm_cond_proj(llm_features) * llm_token_mask + + h = x + llm_features + + image_indicator_embedding = self.embed_image_indicator( + (indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long) + ) + h = h + image_indicator_embedding + + cos, sin = self.rotary_emb(position_ids) + cos = cos.to(h.dtype) + sin = sin.to(h.dtype) + + # Block-diagonal mask from segment ids: (B, 1, L, L), True = attend. + # Only built for the native (SDPA) backend; flash expresses the same + # block structure through cu_seqlens instead. 
def swish(x: Tensor) -> Tensor:
    """Swish activation: each element multiplied by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
class ResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip.

    When the input and output channel counts differ, the skip path goes
    through a 1x1 convolution (``nin_shortcut``) so the two branches can be
    added; otherwise the input is added unchanged.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map. ``None`` (the
            default) keeps ``in_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int | None = None):
        super().__init__()
        self.in_channels = in_channels
        # The body always supported None; the signature now says so too.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            # Only created when needed, so state-dict keys stay unchanged.
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``skip(x) + residual(x)``."""
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h
in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1) + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = 
True + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = ckpt.checkpoint(self.down[i_level].block[i_block], hs[-1]) # type: ignore[index, operator] + if len(self.down[i_level].attn) > 0: # type: ignore[arg-type] + h = ckpt.checkpoint(self.down[i_level].attn[i_block], h) # type: ignore[index, operator] + else: + h = self.down[i_level].block[i_block](hs[-1]) # type: ignore[index, operator] + if len(self.down[i_level].attn) > 0: # type: ignore[arg-type] + h = self.down[i_level].attn[i_block](h) # type: ignore[index, operator] + hs.append(h) + if i_level != self.num_resolutions - 1: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hs.append(ckpt.checkpoint(self.down[i_level].downsample, hs[-1])) # type: ignore[operator] + else: + hs.append(self.down[i_level].downsample(hs[-1])) # type: ignore[operator] + + # middle + h = hs[-1] + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = ckpt.checkpoint(self.mid.block_1, h) # type: ignore[operator] + h = ckpt.checkpoint(self.mid.attn_1, h) # type: ignore[operator] + h = ckpt.checkpoint(self.mid.block_2, h) # type: ignore[operator] + else: + h = self.mid.block_1(h) # type: ignore[operator] + h = self.mid.attn_1(h) # type: ignore[operator] + h = self.mid.block_2(h) # type: ignore[operator] + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + h = self.quant_conv(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1) + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = 
in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def forward(self, z: Tensor) -> Tensor: + z = self.post_quant_conv(z) + + # get dtype for proper tracing + upscale_dtype = next(self.up.parameters()).dtype + + # z to block_in + h = self.conv_in(z) + + # middle + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = ckpt.checkpoint(self.mid.block_1, h) # type: ignore[operator] + h = ckpt.checkpoint(self.mid.attn_1, h) # type: ignore[operator] + h = ckpt.checkpoint(self.mid.block_2, h) # type: ignore[operator] + else: + h = self.mid.block_1(h) # type: 
ignore[operator] + h = self.mid.attn_1(h) # type: ignore[operator] + h = self.mid.block_2(h) # type: ignore[operator] + + # cast to proper dtype + h = h.to(upscale_dtype) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = ckpt.checkpoint(self.up[i_level].block[i_block], h) # type: ignore[index, operator] + if len(self.up[i_level].attn) > 0: # type: ignore[arg-type] + h = ckpt.checkpoint(self.up[i_level].attn[i_block], h) # type: ignore[index, operator] + else: + h = self.up[i_level].block[i_block](h) # type: ignore[index, operator] + if len(self.up[i_level].attn) > 0: # type: ignore[arg-type] + h = self.up[i_level].attn[i_block](h) # type: ignore[index, operator] + if i_level != 0: + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = ckpt.checkpoint(self.up[i_level].upsample, h) # type: ignore[operator] + else: + h = self.up[i_level].upsample(h) # type: ignore[operator] + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.params = params + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + + self.bn_eps = 1e-4 + self.bn_momentum = 0.1 + self.ps = [2, 2] + self.bn = torch.nn.BatchNorm2d( + math.prod(self.ps) * params.z_channels, + eps=self.bn_eps, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + ) + self._gradient_checkpointing = False + + @property + def 
# Default number of resolution levels (len(ch_mult) of the default
# AutoEncoderParams); used to flip diffusers' decoder block indices.
_NUM_RESOLUTIONS = 4


def convert_diffusers_state_dict(
    src: dict[str, Tensor], num_resolutions: int = _NUM_RESOLUTIONS
) -> dict[str, Tensor]:
    """Rename a diffusers-layout VAE state dict to this module's key layout.

    Mid-block attention weights that arrive as 2-D (linear) matrices are
    reshaped to ``(out, in, 1, 1)`` so they load into the 1x1 ``Conv2d``
    projections of ``AttnBlock``.

    Args:
        src: State dict keyed with diffusers names.
        num_resolutions: Number of resolution levels of the target
            autoencoder; needed because decoder up-blocks are indexed in the
            opposite order.

    Returns:
        A new dict with rewritten keys (tensors are shared, not copied, except
        for the reshaped attention weights which are views).

    Raises:
        KeyError: A key in ``src`` matches no known pattern.
    """
    out: dict[str, Tensor] = {}
    attn_substrings = (".mid.attn_1.",)
    for src_key, tensor in src.items():
        dst_key = _rewrite_diffusers_key(src_key, num_resolutions)
        if dst_key is None:
            raise KeyError(f"Unrecognized diffusers VAE state-dict key: {src_key}")
        if (
            any(s in dst_key for s in attn_substrings)
            and dst_key.endswith(".weight")
            and tensor.ndim == 2
        ):
            # Linear (out, in) -> 1x1 conv (out, in, 1, 1).
            tensor = tensor.unsqueeze(-1).unsqueeze(-1)
        out[dst_key] = tensor
    return out


def _rewrite_diffusers_key(
    key: str, num_resolutions: int = _NUM_RESOLUTIONS
) -> str | None:
    """Map one diffusers VAE key to this module's name, or ``None`` if unknown.

    Decoder ``up_blocks.<i>`` map to ``up.<num_resolutions - 1 - i>``; an index
    that would give a negative level is treated as unknown.
    """
    if key.startswith("bn."):
        return key

    if key.startswith("quant_conv."):
        return key.replace("quant_conv.", "encoder.quant_conv.", 1)
    if key.startswith("post_quant_conv."):
        return key.replace("post_quant_conv.", "decoder.post_quant_conv.", 1)

    if key == "encoder.conv_norm_out.weight":
        return "encoder.norm_out.weight"
    if key == "encoder.conv_norm_out.bias":
        return "encoder.norm_out.bias"
    if key == "decoder.conv_norm_out.weight":
        return "decoder.norm_out.weight"
    if key == "decoder.conv_norm_out.bias":
        return "decoder.norm_out.bias"

    m = re.match(r"^(encoder|decoder)\.mid_block\.resnets\.(\d+)\.(.+)$", key)
    if m:
        side, idx, rest = m.group(1), int(m.group(2)), m.group(3)
        rest = rest.replace("conv_shortcut", "nin_shortcut")
        # diffusers resnets are 0-based; ours are block_1 / block_2.
        return f"{side}.mid.block_{idx + 1}.{rest}"
    m = re.match(r"^(encoder|decoder)\.mid_block\.attentions\.0\.(.+)$", key)
    if m:
        side, rest = m.group(1), m.group(2)
        rest = (
            rest.replace("group_norm.", "norm.")
            .replace("to_q.", "q.")
            .replace("to_k.", "k.")
            .replace("to_v.", "v.")
            .replace("to_out.0.", "proj_out.")
        )
        return f"{side}.mid.attn_1.{rest}"

    m = re.match(r"^encoder\.down_blocks\.(\d+)\.resnets\.(\d+)\.(.+)$", key)
    if m:
        level, res_idx, rest = m.group(1), m.group(2), m.group(3)
        rest = rest.replace("conv_shortcut", "nin_shortcut")
        return f"encoder.down.{level}.block.{res_idx}.{rest}"
    m = re.match(r"^encoder\.down_blocks\.(\d+)\.downsamplers\.0\.conv\.(.+)$", key)
    if m:
        return f"encoder.down.{m.group(1)}.downsample.conv.{m.group(2)}"

    m = re.match(r"^decoder\.up_blocks\.(\d+)\.resnets\.(\d+)\.(.+)$", key)
    if m:
        level = num_resolutions - 1 - int(m.group(1))
        if level < 0:
            return None
        res_idx = m.group(2)
        rest = m.group(3).replace("conv_shortcut", "nin_shortcut")
        return f"decoder.up.{level}.block.{res_idx}.{rest}"
    m = re.match(r"^decoder\.up_blocks\.(\d+)\.upsamplers\.0\.conv\.(.+)$", key)
    if m:
        level = num_resolutions - 1 - int(m.group(1))
        if level < 0:
            return None
        return f"decoder.up.{level}.upsample.conv.{m.group(2)}"

    if key.startswith(
        (
            "encoder.conv_in.",
            "encoder.conv_out.",
            "decoder.conv_in.",
            "decoder.conv_out.",
        )
    ):
        return key

    return None
b/ai-toolkit/extensions_built_in/diffusion_models/krea2/krea2.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa90d90d56d5d9c64405cf2260c3ec9e92c81c5 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/krea2/krea2.py @@ -0,0 +1,607 @@ +"""Krea 2 (K2) for ai-toolkit. + +Krea 2 is a single-stream MMDiT text-to-image model: + - text encoder: Qwen3-VL-4B-Instruct (a stack of 12 hidden-state layers is fed + in; ``src/text_encoder.py``), + - autoencoder: the Qwen-Image VAE (f8, 16 latent channels, the same VAE the + ``qwen_image`` arch uses), + - denoiser: ``SingleStreamDiT`` (``src/mmdit.py``), which fuses the text layers + with a small ``TextFusionTransformer`` and runs the packed [text | image] + sequence through ``SingleStreamBlock`` layers. + +Flow-matching convention matches ai-toolkit exactly (t=1 noise -> t=0 clean, +target = noise - clean), so ``get_noise_prediction`` does no time flip / negation. +""" + +import os +from typing import List, Optional + +import torch +from safetensors.torch import load_file, save_file + +import huggingface_hub +from huggingface_hub.errors import EntryNotFoundError +from diffusers import AutoencoderKLQwenImage +from transformers import ( + AutoTokenizer, + Qwen2TokenizerFast, + Qwen3VLForConditionalGeneration, +) +from optimum.quanto import freeze, QTensor + +from toolkit.config_modules import GenerateImageConfig, ModelConfig, NetworkConfig +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.advanced_prompt_embeds import AdvancedPromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import unwrap_model +from toolkit.metadata import get_meta_for_safetensors +from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.memory_management import MemoryManager + +from .src.mmdit import ( + 
DoubleSharedModulation, + SimpleModulation, + SingleMMDiTConfig, + SingleStreamDiT, +) +from .src.text_encoder import encode_krea_prompt, SELECT_LAYERS +from .src.pipeline import Krea2Pipeline, pad_text_features, predict_velocity + + +# The reference "single_mmdit_large_wide" architecture (oss_raw / oss_turbo share it). +KREA2_MMDIT_CONFIG = dict( + features=6144, + tdim=256, + txtdim=2560, + heads=48, + kvheads=12, + multiplier=4, + layers=28, + patch=2, + channels=16, + txtheads=20, + txtkvheads=20, + txtlayers=12, +) + +# Krea 2's mu schedule is exponential time-shifting whose mu is linearly +# interpolated in image-token count between (256-res -> 0.5) and (1280-res -> +# 1.15) -- exactly what CustomFlowMatchEulerDiscreteScheduler's dynamic shifting +# does, so we mirror those endpoints here for the training timestep distribution. +# x1 = (256 // (8*2))**2 = 256 +# x2 = (1280 // (8*2))**2 = 6400 +scheduler_config = { + "base_image_seq_len": 256, + "max_image_seq_len": 6400, + "base_shift": 0.5, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 1.0, + "use_dynamic_shifting": True, + "time_shift_type": "exponential", +} + +# Defaults; both overridable via model.model_kwargs. +QWEN3_VL_PATH = "Qwen/Qwen3-VL-4B-Instruct" +QWEN_IMAGE_VAE_PATH = "Qwen/Qwen-Image" + +HF_TOKEN = os.getenv("HF_TOKEN", None) + + +def _load_mmdit_state_dict(name_or_path: str, filename: Optional[str]) -> dict: + """Load the MMDiT weights from a local safetensors file/dir or the HF hub. + + ``name_or_path`` may be: a ``.safetensors`` file, a directory containing one + (``filename`` or the lone ``.safetensors`` in it), or a hub repo id (the + file ``filename`` is downloaded, defaulting to ``model.safetensors``). 
+ """ + if name_or_path.endswith(".safetensors") and os.path.isfile(name_or_path): + return load_file(name_or_path) + + if os.path.isdir(name_or_path): + if filename is not None: + return load_file(os.path.join(name_or_path, filename)) + candidates = [f for f in os.listdir(name_or_path) if f.endswith(".safetensors")] + if len(candidates) == 1: + return load_file(os.path.join(name_or_path, candidates[0])) + raise FileNotFoundError( + f"Could not pick an MMDiT checkpoint in {name_or_path}: found " + f"{candidates}. Set model.model_kwargs.checkpoint_filename." + ) + + # Treat as a hub repo id. When no filename is given, derive it from the repo + # name's trailing segment (e.g. "krea/Krea-2-Raw" -> "raw.safetensors", + # "krea/Krea-2-Turbo" -> "turbo.safetensors"). + fname = filename or ( + name_or_path.split("/")[-1].split("-")[-1].lower() + ".safetensors" + ) + try: + path = huggingface_hub.hf_hub_download( + repo_id=name_or_path, filename=fname, token=HF_TOKEN + ) + except EntryNotFoundError as e: + raise FileNotFoundError( + f"Could not find {fname!r} in hub repo {name_or_path!r}. Set " + "model.model_kwargs.checkpoint_filename to the weight file name." + ) from e + return load_file(path) + + +class Krea2Model(BaseModel): + arch = "krea2" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["SingleStreamDiT"] + + self.patch_size = KREA2_MMDIT_CONFIG["patch"] + self.vae_scale_factor = 8 # Qwen-Image VAE is f8 + # Safety cap on prompt token length (truncation only); embeds are stored + # per-sample at natural length and padded to the batch max at the model call. 
+ self.max_text_length = int( + self.model_config.model_kwargs.get("max_text_length", 512) + ) + # Qwen2TokenizerFast used to tokenize the assistant suffix (matches the + # reference's separate processor pass). + self.processor = None + self.use_old_lokr_format = False + + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + # 8 for the VAE downsample, 2 for the patch size. + return self.vae_scale_factor * self.patch_size + + # ------------------------------------------------------------------ + # Loading + # ------------------------------------------------------------------ + def _load_transformer(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading transformer (SingleStreamDiT)") + + mmdit_kwargs = dict(KREA2_MMDIT_CONFIG) + mmdit_kwargs.update(self.model_config.model_kwargs.get("mmdit_config", {})) + config = SingleMMDiTConfig(**mmdit_kwargs) + + # Build on meta, then materialize straight from the checkpoint. 
+ with torch.device("meta"): + transformer = SingleStreamDiT(config) + + self.print_and_status_update(" - fetching transformer weights") + state_dict = _load_mmdit_state_dict( + self.model_config.name_or_path, + self.model_config.model_kwargs.get("checkpoint_filename", None), + ) + state_dict = { + k: (v.to(dtype) if v.is_floating_point() else v) + for k, v in state_dict.items() + } + self.print_and_status_update(" - loading transformer state dict") + transformer.load_state_dict(state_dict, strict=True, assign=True) + del state_dict + flush() + return transformer + + def _load_text_encoder(self): + dtype = self.torch_dtype + te_path = self.model_config.model_kwargs.get("text_encoder_path", QWEN3_VL_PATH) + self.print_and_status_update(f"Loading Qwen3-VL text encoder from {te_path}") + + tokenizer = AutoTokenizer.from_pretrained( + te_path, max_length=self.max_text_length, token=HF_TOKEN + ) + processor = Qwen2TokenizerFast.from_pretrained( + te_path, max_length=self.max_text_length, token=HF_TOKEN + ) + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( + te_path, torch_dtype=dtype, token=HF_TOKEN + ) + # We only ever encode text, so the vision tower is dead weight -- drop it to + # free VRAM and skip loading its (bf16-slow) Conv3d patch_embed onto the GPU. 
def load_training_adapter(self, transformer: SingleStreamDiT):
    """Load the assistant LoRA, merge it into ``transformer`` and keep a handle to it.

    ``self.model_config.assistant_lora_path`` is either an existing local file
    or a hub path of the form ``<owner>/<repo>/<filename>``. Hub paths are
    downloaded with ``hf_hub_download`` and the config is updated to point at
    the local copy.

    After this call the adapter weights are merged into the transformer, the
    network is stored on ``self.assistant_lora`` with ``multiplier = -1.0`` and
    ``is_active = False``, and ``self.invert_assistant_lora`` is ``True``.

    Raises:
        ValueError: the path is neither a local file nor a 3-part hub path,
            the download fails, or the file holds no ``lora_A.weight`` tensor
            from which the rank could be inferred.
    """
    self.print_and_status_update("Loading assistant LoRA")
    lora_path = self.model_config.assistant_lora_path
    if not os.path.exists(lora_path):
        # assume it is a hub path: <owner>/<repo>/<filename>
        lora_splits = lora_path.split("/")
        if len(lora_splits) != 3:
            raise ValueError(
                f"Assistant LoRA path {lora_path} is not a valid local path or hub path."
            )
        repo_id = "/".join(lora_splits[:2])
        filename = lora_splits[2]
        try:
            lora_path = huggingface_hub.hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                token=HF_TOKEN,
            )
            # upgrade path to the local download
            self.model_config.assistant_lora_path = lora_path
        except Exception as e:
            # Chain the original error so the underlying traceback is preserved.
            raise ValueError(
                f"Failed to download assistant LoRA from {lora_path}: {e}"
            ) from e

    # load the adapter and merge it in. We will inference with a -1.0 multiplier
    # so the adapter effects only work during training.
    lora_state_dict = load_file(lora_path)

    # detect the LoRA rank from the first down-projection weight.
    dim_key = next(
        (k for k in lora_state_dict if k.endswith("lora_A.weight")), None
    )
    if dim_key is None:
        # A bare StopIteration here would be confusing (and is swallowed inside
        # generators), so fail with an explicit message instead.
        raise ValueError(
            f"Assistant LoRA {lora_path} contains no 'lora_A.weight' tensors; "
            "cannot infer the LoRA rank."
        )
    dim = int(lora_state_dict[dim_key].shape[0])

    # Same "diffusion_model." -> "transformer." remap used for regular LoRA loads.
    lora_state_dict = self.convert_lora_weights_before_load(lora_state_dict)

    network_config = NetworkConfig(
        **{
            "type": "lora",
            "linear": dim,
            "linear_alpha": dim,
            "transformer_only": True,
        }
    )
    LoRASpecialNetwork.LORA_PREFIX_UNET = "lora_transformer"
    network = LoRASpecialNetwork(
        text_encoder=None,
        unet=transformer,
        lora_dim=network_config.linear,
        multiplier=1.0,
        alpha=network_config.linear_alpha,
        train_unet=True,
        train_text_encoder=False,
        network_config=network_config,
        network_type=network_config.type,
        transformer_only=network_config.transformer_only,
        is_transformer=True,
        target_lin_modules=self.target_lora_modules,
        is_assistant_adapter=True,
        is_ara=True,
    )
    network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
    self.print_and_status_update("Merging in assistant LoRA")
    network.force_to(self.device_torch, dtype=self.torch_dtype)
    network._update_torch_multiplier()
    network.load_weights(lora_state_dict)

    network.merge_in(merge_weight=1.0)

    # mark it as not merged so inference ignores it.
    network.is_merged_in = False

    # add the assistant so sampler will activate it while sampling
    self.assistant_lora: LoRASpecialNetwork = network

    # deactivate lora during training
    self.assistant_lora.multiplier = -1.0
    self.assistant_lora.is_active = False

    # tell the model to invert assistant on inference since we want remove lora effects
    self.invert_assistant_lora = True
def generate_single_image(
    self,
    pipeline: Krea2Pipeline,
    gen_config: GenerateImageConfig,
    conditional_embeds: AdvancedPromptEmbeds,
    unconditional_embeds: AdvancedPromptEmbeds,
    generator: torch.Generator,
    extra: dict,
):
    """Render one preview image with ``pipeline`` and return it.

    The transformer is moved to ``self.device_torch`` if it currently sits on
    the CPU. ``gen_config.width`` / ``gen_config.height`` are floored in place
    to a multiple of ``self.get_bucket_divisibility()`` before sampling.
    ``extra`` is accepted for interface compatibility and is not used here.
    """
    cpu = torch.device("cpu")
    if self.model.device == cpu:
        self.model.to(self.device_torch)

    # Snap the requested size down to the nearest supported multiple.
    multiple = self.get_bucket_divisibility()
    for side in ("width", "height"):
        requested = getattr(gen_config, side)
        setattr(gen_config, side, int(requested // multiple * multiple))

    images = pipeline(
        conditional_embeds=conditional_embeds,
        unconditional_embeds=unconditional_embeds,
        height=gen_config.height,
        width=gen_config.width,
        num_inference_steps=gen_config.num_inference_steps,
        guidance_scale=gen_config.guidance_scale,
        latents=gen_config.latents,
        generator=generator,
    )
    # The pipeline returns a list; previews use only the first entry.
    return images[0]
def get_prompt_embeds(self, prompt) -> AdvancedPromptEmbeds:
    """Encode one prompt (``str``) or several (iterable of ``str``).

    Each prompt is encoded on its own via ``encode_krea_prompt`` at its natural
    length, and the resulting ``(L, n, d)`` tensor is flattened to
    ``(L, n * d)``. Keeping every entry 2D lets the toolkit's batching read the
    list length as the batch size; ``predict_velocity`` restores the layer axis
    and padding is deferred to the model call.

    Returns an ``AdvancedPromptEmbeds`` whose ``text_embeds`` is a list with
    one tensor per prompt, cast to ``self.torch_dtype``.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt

    # The encoder may have been parked on the CPU (low-VRAM mode).
    if self.text_encoder.device == torch.device("cpu"):
        self.text_encoder.to(self.device_torch)

    def _encode_one(text):
        stacked = encode_krea_prompt(
            self.text_encoder,
            self.tokenizer,
            self.processor,
            text,
            max_length=self.max_text_length,
            select_layers=SELECT_LAYERS,
        )
        # (L, n, d) -> (L, n*d)
        flat = stacked.reshape(stacked.shape[0], -1)
        return flat.to(self.torch_dtype)

    return AdvancedPromptEmbeds(text_embeds=[_encode_one(p) for p in prompts])
def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
    """Encode a list of image tensors into normalized VAE latents.

    Args:
        image_list: per-image tensors that can be stacked into one batch.
        device: target device; defaults to ``self.vae_device_torch``.
        dtype: target dtype; defaults to ``self.vae_torch_dtype``.

    Returns:
        Latents with the temporary frame axis removed, normalized as
        ``(latents - latents_mean) * (1 / latents_std)`` using the statistics
        from ``self.vae.config``, on ``device`` / ``dtype``.
    """
    device = self.vae_device_torch if device is None else device
    dtype = self.vae_torch_dtype if dtype is None else dtype

    vae = self.vae
    if vae.device == torch.device("cpu"):
        vae.to(device)
    vae.eval()
    vae.requires_grad_(False)

    batch = torch.stack([img.to(device, dtype=dtype) for img in image_list])
    batch = batch.to(device, dtype=dtype)

    # AutoencoderKLQwenImage is a video VAE: add a frame dim.
    latents = vae.encode(batch.unsqueeze(2)).latent_dist.sample()

    stat_shape = (1, vae.config.z_dim, 1, 1, 1)
    mean = torch.tensor(vae.config.latents_mean).view(*stat_shape)
    mean = mean.to(latents.device, latents.dtype)
    inv_std = 1.0 / torch.tensor(vae.config.latents_std).view(*stat_shape).to(
        latents.device, latents.dtype
    )

    normalized = (latents - mean) * inv_std
    # drop frame dim
    return normalized.squeeze(2).to(device, dtype=dtype)
def save_model(self, output_path, meta, save_dtype):
    """Write the full transformer state dict to a ``.safetensors`` file.

    Args:
        output_path: destination path; ``.safetensors`` is appended when the
            path does not already end with it.
        meta: metadata passed through ``get_meta_for_safetensors`` (with
            ``name="krea2"``) before being stored in the file.
        save_dtype: dtype every tensor is cast to before saving.
    """
    suffix = ".safetensors"
    if not output_path.endswith(suffix):
        output_path = output_path + suffix

    transformer: SingleStreamDiT = unwrap_model(self.model)

    def _export(tensor):
        # Quantized weights are dequantized before being written out.
        if isinstance(tensor, QTensor):
            tensor = tensor.dequantize()
        return tensor.clone().to("cpu", dtype=save_dtype)

    save_dict = {
        name: _export(tensor) for name, tensor in transformer.state_dict().items()
    }
    meta = get_meta_for_safetensors(meta, name="krea2")
    save_file(save_dict, output_path, metadata=meta)
a/ai-toolkit/extensions_built_in/diffusion_models/krea2/src/mmdit.py b/ai-toolkit/extensions_built_in/diffusion_models/krea2/src/mmdit.py new file mode 100644 index 0000000000000000000000000000000000000000..e40710f30e145ff6d7bbcd30fb1363f7d039bcb5 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/krea2/src/mmdit.py @@ -0,0 +1,461 @@ +"""Krea 2 (K2) single-stream MMDiT backbone. + +Vendored from the reference ``mmdit.py`` for ai-toolkit. This is a single-stream +MMDiT: Qwen3-VL text features are fused by a small ``TextFusionTransformer`` and +then concatenated with the patchified image latent tokens into one sequence that +flows through ``SingleStreamBlock`` layers. The model predicts the flow-matching +velocity on the image tokens. + +Differences from the reference (all training-driven, numerically equivalent): + - ``torch.compile`` decorators are dropped (they fight gradient checkpointing, + LoRA module swapping and variable shapes during training). + - Attention uses a plain ``F.scaled_dot_product_attention`` instead of forcing + the cuDNN SDPA backend, so it works across dtypes / masks / backward. + - ``enable_gradient_checkpointing`` / ``disable_gradient_checkpointing`` and a + per-block ``torch.utils.checkpoint`` wrapper are added (gated on + ``torch.is_grad_enabled()`` so eval/sampling never pays for it). 
+""" + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils.checkpoint import checkpoint + + +def rope(pos: Tensor, dim: int, theta: float = 1e4, ntk: float = 1.0) -> Tensor: + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / ((theta * ntk) ** scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack( + [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1 + ) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def ropeapply(xq: Tensor, xk: Tensor, freqs: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + freqs = freqs[:, None, :, :, :] + xq_ = freqs[..., 0] * xq_[..., 0] + freqs[..., 1] * xq_[..., 1] + xk_ = freqs[..., 0] * xk_[..., 0] + freqs[..., 1] * xk_[..., 1] + return xq_.reshape(*xq.shape).to(xq.dtype), xk_.reshape(*xk.shape).to(xk.dtype) + + +def attention( + q: Tensor, + k: Tensor, + v: Tensor, + mask: Tensor | None = None, + scale: float | None = None, + gqa: bool = False, +) -> Tensor: + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, scale=scale, enable_gqa=gqa + ) + return rearrange(x, "B H L D -> B L (H D)") + + +def _mask(mask: Tensor) -> Tensor: + """Expand a (B, L) key-padding mask into a (B, 1, L, L) attention mask.""" + return mask.unsqueeze(1).unsqueeze(2) * mask.unsqueeze(1).unsqueeze(3) + + +def temb( + t: Tensor, + dim: int, + period: float = 1e4, + tfactor: float = 1e3, + device: torch.device = None, + dtype: torch.dtype = None, +) -> Tensor: + half = dim // 2 + freqs = torch.exp( + -math.log(period) + * torch.arange(half, dtype=torch.float32, device=device) + / half + ) + # t: 
class SimpleModulation(torch.nn.Module):
    """Learned additive offsets that turn one conditioning vector into (scale, shift).

    ``lin`` holds two zero-initialized rows of size ``dim``: row 0 offsets the
    scale and row 1 offsets the shift.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.lin = torch.nn.Parameter(torch.zeros(2, dim))
        # Number of chunks the modulated tensor is split into.
        self.multiplier = 2

    def forward(self, vec: Tensor):
        """Return ``(scale, shift)``.

        ``vec`` is broadcast against ``lin`` reshaped to ``(1, 2, dim)`` and the
        sum is split in two along dim 1. With a ``(B, 1, dim)`` input (what
        ``LastLayer`` receives in this file) each output is ``(B, 1, dim)``.
        """
        offsets = self.lin.unsqueeze(0)  # (2, dim) -> (1, 2, dim)
        scale, shift = torch.chunk(vec + offsets, self.multiplier, dim=1)
        return scale, shift
+class RMSNorm(torch.nn.Module): + def __init__(self, features: int, eps: float = 1e-05, device: torch.device = None): + super().__init__() + self.features = features + self.eps = eps + self.scale = torch.nn.Parameter( + torch.zeros(features, device=device, dtype=torch.float32) + ) + + def forward(self, x: Tensor) -> Tensor: + t, dtype = x.float(), x.dtype + t = F.rms_norm( + t, (self.features,), eps=self.eps, weight=(self.scale.float() + 1.0) + ) + return t.to(dtype) + + +class SwiGLU(torch.nn.Module): + def __init__( + self, features: int, multiplier: int, bias: bool = False, multiple: int = 128 + ): + super().__init__() + + mlpdim = int(2 * features / 3) * multiplier + mlpdim = multiple * ((mlpdim + multiple - 1) // multiple) + + self.gate = torch.nn.Linear(features, mlpdim, bias=bias) + self.up = torch.nn.Linear(features, mlpdim, bias=bias) + self.down = torch.nn.Linear(mlpdim, features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + return self.down(F.silu(self.gate(x)) * self.up(x)) + + +class Attention(torch.nn.Module): + def __init__(self, dim: int, heads: int, kvheads: int = None, bias: bool = False): + super().__init__() + self.heads = heads + self.kvheads = kvheads if kvheads is not None else heads + self.headdim = dim // self.heads + + self.wq = torch.nn.Linear(dim, self.headdim * self.heads, bias=bias) + self.wk = torch.nn.Linear(dim, self.headdim * self.kvheads, bias=bias) + self.wv = torch.nn.Linear(dim, self.headdim * self.kvheads, bias=bias) + self.gate = torch.nn.Linear(dim, dim, bias=bias) + self.qknorm = QKNorm(self.headdim) + self.gqa = self.heads != self.kvheads + self.wo = torch.nn.Linear(dim, dim, bias=bias) + + def forward( + self, qkv: Tensor, freqs: Tensor | None = None, mask: Tensor | None = None + ) -> Tensor: + q, k, v, gate = self.wq(qkv), self.wk(qkv), self.wv(qkv), self.gate(qkv) + + q, k, v = ( + rearrange(q, "B L (H D) -> B H L D", H=self.heads), + rearrange(k, "B L (H D) -> B H L D", H=self.kvheads), + rearrange(v, "B L 
class LastLayer(torch.nn.Module):
    """Final modulated norm + linear projection back to patch pixels.

    Projects each token from ``features`` to ``patch * patch * channels``.
    """

    def __init__(self, features: int, patch: int, channels: int):
        super().__init__()
        out_features = patch * patch * channels
        self.norm = RMSNorm(features)
        self.linear = torch.nn.Linear(features, out_features, bias=True)
        self.modulation = SimpleModulation(features)

    def forward(self, x: Tensor, tvec: Tensor) -> Tensor:
        """Apply ``(1 + scale) * norm(x) + shift`` and then the output projection.

        ``scale`` and ``shift`` come from ``self.modulation(tvec)``.
        """
        scale, shift = self.modulation(tvec)
        modulated = self.norm(x) * (1 + scale) + shift
        return self.linear(modulated)
class SingleStreamDiT(nn.Module):
    """Single-stream MMDiT: fused text tokens and image tokens share one sequence.

    Text features are fused by ``TextFusionTransformer`` and projected by
    ``txtmlp``, concatenated in front of the embedded image tokens, run through
    ``config.layers`` ``SingleStreamBlock`` layers, and the image-token slice of
    the final projection is returned.
    """

    def __init__(self, config: SingleMMDiTConfig):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False

        # Split the head dimension across the three position axes.
        headdim = config.features // config.heads
        axes = [
            headdim - 12 * (headdim // 16),
            6 * (headdim // 16),
            6 * (headdim // 16),
        ]
        assert sum(axes) == headdim, f"sum(axes) = {sum(axes)}, headdim = {headdim}"
        assert all(a % 2 == 0 for a in axes), f"axes = {axes}"

        self.posemb = PositionalEncoding(
            config.features, axes, theta=config.theta, ntk=1.0
        )
        self.first = nn.Linear(
            config.channels * config.patch**2, config.features, bias=True
        )

        self.blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    config.features,
                    config.heads,
                    config.multiplier,
                    config.bias,
                    config.kvheads,
                )
                for _ in range(config.layers)
            ]
        )
        self.tmlp = nn.Sequential(
            nn.Linear(config.tdim, config.features),
            nn.GELU(approximate="tanh"),
            nn.Linear(config.features, config.features),
        )
        self.txtfusion = TextFusionTransformer(
            config.txtlayers,
            config.txtdim,
            config.txtheads,
            config.multiplier,
            config.bias,
            config.txtkvheads,
        )
        self.txtmlp = nn.Sequential(
            RMSNorm(config.txtdim),
            nn.Linear(config.txtdim, config.features),
            nn.GELU(approximate="tanh"),
            nn.Linear(config.features, config.features),
        )
        self.last = LastLayer(config.features, config.patch, config.channels)

        self.tproj = nn.Sequential(
            nn.GELU(approximate="tanh"), nn.Linear(config.features, config.features * 6)
        )

    @property
    def device(self) -> torch.device:
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def forward(
        self,
        img: Tensor,
        context: Tensor,
        t: Tensor,
        pos: Tensor,
        mask: Tensor | None = None,
    ) -> Tensor:
        """Run the backbone on the packed ``[text | image]`` sequence.

        Args:
            img: patchified image tokens; ``img.shape[1]`` is the image length.
            context: stacked text features; ``context.shape[1]`` is the text
                length (unpacked as ``(b, l, n, d)`` by ``TextFusionTransformer``).
            t: per-sample time passed to ``temb``.
            pos: positions for the combined text + image sequence.
            mask: key-padding mask over the combined sequence, truthy for real
                tokens. ``None`` means every token is valid.

        Returns:
            The final projection restricted to the image tokens.
        """
        if mask is None:
            # The signature advertises ``mask=None``; treat it as "nothing is
            # padded" instead of failing on the slice below.
            mask = torch.ones(
                img.shape[0],
                context.shape[1] + img.shape[1],
                dtype=torch.bool,
                device=img.device,
            )
        else:
            # ``_mask`` multiplies the mask with itself; keep it boolean so the
            # result is a valid SDPA mask even if a 0/1 integer mask is passed.
            mask = mask.to(torch.bool)

        img = self.first(img)
        t = self.tmlp(temb(t, self.config.tdim, device=img.device, dtype=img.dtype))
        tvec = self.tproj(t)

        txtmask = _mask(mask[:, : context.shape[1]])

        context = self.txtfusion(context, mask=txtmask)
        context = self.txtmlp(context)

        txtlen, imglen = context.shape[1], img.shape[1]
        combined = torch.cat((context, img), dim=1)

        # Pad combined sequence to a multiple of 256 to stabilize compiled kernel shapes.
        fulllen = combined.shape[1]
        _padlen = (-fulllen) % 256
        if _padlen > 0:
            combined = F.pad(combined, (0, 0, 0, _padlen))
            mask = F.pad(mask, (0, _padlen), value=False)
            pos = F.pad(pos, (0, 0, 0, _padlen))

        mask = _mask(mask)

        freqs = self.posemb(pos)

        for block in self.blocks:
            # Checkpoint only when gradients are tracked, so eval/sampling never
            # pays the recompute cost.
            if self.gradient_checkpointing and torch.is_grad_enabled():
                combined = checkpoint(
                    block,
                    combined,
                    tvec,
                    freqs,
                    mask,
                    use_reentrant=False,
                )
            else:
                combined = block(combined, tvec, freqs, mask)

        final = self.last(combined, t)
        # Keep only the image tokens (drops the text prefix and the padding tail).
        output = final[:, txtlen : txtlen + imglen, :]

        return output
def pad_text_features(
    features_list: List[torch.Tensor],
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad per-sample ``(Lt_i, F)`` text features into one batch.

    Every entry keeps its natural length; shorter entries are zero-padded up to
    the longest one in the list.

    Returns:
        ``(features, mask)`` where ``features`` is ``(B, Lt, F)`` on
        ``device`` / ``dtype`` and ``mask`` is a ``(B, Lt)`` long tensor holding
        1 for real text tokens and 0 for padding.
    """
    lengths = [feats.shape[0] for feats in features_list]
    longest = max(lengths)
    width = features_list[0].shape[-1]

    padded = torch.zeros(
        len(features_list), longest, width, device=device, dtype=dtype
    )
    mask = torch.zeros(len(features_list), longest, dtype=torch.long, device=device)
    for row, (feats, length) in enumerate(zip(features_list, lengths)):
        padded[row, :length] = feats.to(device, dtype)
        mask[row, :length] = 1
    return padded, mask
def prepare(
    img: torch.Tensor, txtlen: int, patch: int, txtmask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Patchify the latent and build the combined text+image position / mask.

    Args:
        img: ``(B, C, h, w)`` image latent.
        txtlen: number of text tokens.
        patch: transformer patch size.
        txtmask: ``(B, txtlen)`` long/bool mask, truthy for real text tokens.

    Returns:
        ``(img_tokens, pos, mask)``: tokens are ``(B, h/p * w/p, C*p*p)``, ``pos``
        is ``(B, txtlen + imglen, 3)`` with all-zero text positions and
        ``(0, row, col)`` image positions, and ``mask`` is a ``(B, txtlen + imglen)``
        bool tensor (image tokens are always valid).
    """
    batch, channels, height, width = img.shape
    rows, cols = height // patch, width // patch
    dev = img.device

    # (0, row, col) id for every patch, in row-major order.
    grid = torch.zeros((rows, cols, 3), device=dev)
    grid[:, :, 1] = torch.arange(rows, device=dev).unsqueeze(1)
    grid[:, :, 2] = torch.arange(cols, device=dev).unsqueeze(0)
    imgpos = grid.reshape(1, rows * cols, 3).repeat(batch, 1, 1)
    imgmask = torch.ones(batch, rows * cols, device=dev, dtype=torch.bool)

    # "b c (h ph) (w pw) -> b (h w) (c ph pw)"
    tokens = img.reshape(batch, channels, rows, patch, cols, patch)
    tokens = tokens.permute(0, 2, 4, 1, 3, 5)
    tokens = tokens.reshape(batch, rows * cols, channels * patch * patch)

    txtpos = torch.zeros(batch, txtlen, 3, device=dev)
    mask = torch.cat((txtmask.to(dev).bool(), imgmask), dim=1)
    pos = torch.cat((txtpos, imgpos), dim=1)
    return tokens, pos, mask
def timesteps(
    seq_len: int,
    steps: int,
    x1: float,
    x2: float,
    y1: float = 0.5,
    y2: float = 1.15,
    sigma: float = 1.0,
    mu: Optional[float] = None,
) -> List[float]:
    """Build the shifted flow-matching schedule, running from t=1 down to t=0.

    A uniform grid of ``steps + 1`` points from 1 to 0 is warped by
    ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``. When ``mu`` is ``None`` it is
    taken from the straight line through ``(x1, y1)`` and ``(x2, y2)`` evaluated
    at ``seq_len``; passing ``mu`` explicitly pins the shift regardless of
    ``seq_len``.

    Returns:
        The ``steps + 1`` schedule values as a plain list of floats.
    """
    if mu is None:
        # Linear interpolation of mu in the image sequence length.
        slope = (y2 - y1) / (x2 - x1)
        mu = slope * seq_len + (y1 - slope * x1)

    shift = math.exp(mu)
    grid = torch.linspace(1, 0, steps + 1)
    warped = shift / (shift + (1.0 / grid - 1.0) ** sigma)
    return warped.tolist()
class Krea2Pipeline:
    """Lightweight flow-matching sampler used by ai-toolkit's preview generation."""

    def __init__(self, model):
        # ``model`` is the owning Krea2Model; its transformer, decode helper and
        # config are reused rather than duplicated here.
        self.model = model

    @property
    def device(self):
        return self.model.device_torch

    def to(self, *args, **kwargs):
        # Nothing to move: device placement is owned by the wrapped model.
        return self

    def set_progress_bar_config(self, **kwargs):
        # No progress bar in this sampler; accepted for API compatibility.
        pass

    @torch.no_grad()
    def __call__(
        self,
        conditional_embeds,
        unconditional_embeds,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 28,
        guidance_scale: float = 4.5,
        latents: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> List[Image.Image]:
        """Sample images by Euler-integrating the flow ODE and decoding the result.

        Schedule options are read from ``model_config.model_kwargs``
        (``schedule_y1``, ``schedule_y2``, ``schedule_min_res``,
        ``schedule_max_res``, ``schedule_mu``). Guidance is applied only when
        ``guidance_scale > 0`` and ``unconditional_embeds`` is given, as
        ``v_cond + guidance_scale * (v_cond - v_uncond)``.

        Returns:
            One PIL image per latent in the batch.
        """
        owner = self.model
        device = owner.device_torch
        dtype = owner.torch_dtype
        transformer: SingleStreamDiT = owner.transformer
        ae_scale = owner.vae_scale_factor  # 8
        # Pixels covered by one transformer token along each side.
        align = ae_scale * owner.patch_size

        opts = owner.model_config.model_kwargs
        y1 = float(opts.get("schedule_y1", 0.5))
        y2 = float(opts.get("schedule_y2", 1.15))
        minres = int(opts.get("schedule_min_res", 256))
        maxres = int(opts.get("schedule_max_res", 1280))
        mu = opts.get("schedule_mu", None)
        if mu is not None:
            mu = float(mu)

        do_cfg = guidance_scale > 0 and unconditional_embeds is not None

        # Starting gaussian noise in the (B, C, h8, w8) latent layout.
        if latents is None:
            noise_shape = (
                1,
                transformer.config.channels,
                height // ae_scale,
                width // ae_scale,
            )
            latents = randn_tensor(
                noise_shape, generator=generator, device=device, dtype=torch.float32
            )
        latents = latents.to(device, dtype=torch.float32)

        cond_feats, cond_mask = pad_text_features(
            conditional_embeds.text_embeds, device, dtype
        )
        if do_cfg:
            uncond_feats, uncond_mask = pad_text_features(
                unconditional_embeds.text_embeds, device, dtype
            )

        # min_res / max_res define the (x1,y1)-(x2,y2) interpolation endpoints for mu.
        seq_len = (height // align) * (width // align)
        schedule = timesteps(
            seq_len,
            num_inference_steps,
            (minres // align) ** 2,
            (maxres // align) ** 2,
            y1=y1,
            y2=y2,
            mu=mu,
        )

        # Euler integration of the flow ODE (with optional CFG). Latents are
        # accumulated in float32; the model runs in ``dtype``.
        for idx in range(len(schedule) - 1):
            tcurr, tprev = schedule[idx], schedule[idx + 1]
            t = torch.full((latents.shape[0],), tcurr, dtype=dtype, device=device)
            v = predict_velocity(
                transformer, latents.to(dtype), t, cond_feats, cond_mask
            )
            if do_cfg:
                v_uncond = predict_velocity(
                    transformer, latents.to(dtype), t, uncond_feats, uncond_mask
                )
                v = v + guidance_scale * (v - v_uncond)
            latents = latents + (tprev - tcurr) * v.to(torch.float32)

        # [-1, 1] float images -> uint8 HWC arrays -> PIL.
        decoded = owner.decode_latents(latents, device=device, dtype=dtype)
        decoded = decoded.float().clamp(-1.0, 1.0)
        pixels = ((decoded + 1.0) * 127.5).round().to(torch.uint8)
        arrays = pixels.permute(0, 2, 3, 1).cpu().numpy()
        return [Image.fromarray(arr) for arr in arrays]
# Indices of the Qwen3-VL hidden states that get stacked for the MMDiT (12 layers).
SELECT_LAYERS: tuple[int, ...] = (2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35)

# Instruction template wrapped around every prompt. The prefix is run through the
# encoder as context, but its hidden states are removed from the returned features.
PROMPT_TEMPLATE_ENCODE_PREFIX = (
    "<|im_start|>system\nDescribe the image by detailing the color, shape, size, "
    "texture, quantity, text, spatial relationships of the objects and "
    "background:<|im_end|>\n<|im_start|>user\n"
)
PROMPT_TEMPLATE_ENCODE_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"

# How many leading positions (the system prefix) are cut from the encoded features.
PROMPT_TEMPLATE_ENCODE_START_IDX = 34


@torch.no_grad()
def encode_krea_prompt(
    qwen,
    tokenizer,
    processor,
    prompt: str,
    max_length: int = 512,
    select_layers: tuple[int, ...] = SELECT_LAYERS,
    prefix_idx: int = PROMPT_TEMPLATE_ENCODE_START_IDX,
) -> Tensor:
    """Turn one prompt into a stack of selected encoder hidden states.

    Args:
        qwen: encoder; called with ``input_ids``, ``attention_mask`` and
            ``output_hidden_states=True`` and must expose ``.device``.
        tokenizer: tokenizes ``prefix + prompt`` (truncated, unpadded).
        processor: tokenizes the fixed assistant suffix on its own.
        prompt: the user prompt text.
        max_length: token budget for the prompt; the tokenizer is given
            ``max_length + prefix_idx`` so the prefix does not eat into it.
        select_layers: indices into ``hidden_states`` to stack.
        prefix_idx: number of leading positions to drop from the result.

    Returns:
        A ``(L, len(select_layers), hidden)`` tensor for the first (only) batch
        item, where ``L`` is the unpadded prompt + suffix length after the prefix
        positions have been removed.
    """
    target = qwen.device

    # The suffix goes through the processor separately so it does not pick up
    # whatever extras the main tokenizer adds.
    suffix_batch = processor(
        text=[PROMPT_TEMPLATE_ENCODE_SUFFIX], return_tensors="pt"
    ).to(target, non_blocking=True)

    # Prefix + prompt at natural length; only overly long prompts are truncated.
    prompt_batch = tokenizer(
        [PROMPT_TEMPLATE_ENCODE_PREFIX + prompt],
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        max_length=max_length + prefix_idx,
        return_tensors="pt",
    ).to(target, non_blocking=True)

    token_ids = torch.cat(
        [prompt_batch["input_ids"], suffix_batch["input_ids"]], dim=1
    )
    attention = torch.cat(
        [
            prompt_batch["attention_mask"].bool(),
            suffix_batch["attention_mask"].bool(),
        ],
        dim=1,
    )

    outputs = qwen(
        input_ids=token_ids, attention_mask=attention, output_hidden_states=True
    )

    # Stack the chosen layers on a new axis: (1, L_total, num_layers, hidden).
    per_layer = [outputs.hidden_states[layer] for layer in select_layers]
    stacked = torch.stack(per_layer, dim=2)
    # Take the single batch item and skip the system-prefix positions.
    return stacked[0, prefix_idx:]
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift", + "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate", + # Transformer Blocks + # Per-Block Cross Attention Modulatin Parameters + "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table", + "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT = { + **LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT, + "audio_prompt_adaln_single": "audio_prompt_adaln", + "prompt_adaln_single": "prompt_adaln", +} + +LTX_2_0_VIDEO_VAE_RENAME_DICT = { + # Encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # Decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", + # Common + # For all 3D ResNets + "res_blocks": "resnets", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + +LTX_2_3_VIDEO_VAE_RENAME_DICT = { + **LTX_2_0_VIDEO_VAE_RENAME_DICT, + # Decoder extra blocks + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", +} + +LTX_2_0_AUDIO_VAE_RENAME_DICT = { + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + +LTX_2_0_VOCODER_RENAME_DICT = { + "ups": "upsamplers", + "resblocks": 
"resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", +} + +LTX_2_3_VOCODER_RENAME_DICT = { + # Handle upsamplers ("ups" --> "upsamplers") due to name clash + "resblocks": "resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", + "act_post": "act_out", + "downsample.lowpass": "downsample", +} + +LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = { + "connectors.": "", + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + "text_embedding_projection.aggregate_embed": "text_proj_in", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +LTX_2_3_CONNECTORS_KEYS_RENAME_DICT = { + "connectors.": "", + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + # LTX-2.3 uses per-modality embedding projections + "text_embedding_projection.audio_aggregate_embed": "audio_text_proj_in", + "text_embedding_projection.video_aggregate_embed": "video_text_proj_in", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + + +def update_state_dict_inplace( + state_dict: Dict[str, Any], old_key: str, new_key: str +) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None: + state_dict.pop(key) + + +def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias + if ".weight" not in key and ".bias" not in key: + return + + if key.startswith("adaln_single."): + new_key = key.replace("adaln_single.", "time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + if key.startswith("audio_adaln_single."): + new_key = key.replace("audio_adaln_single.", "audio_time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + +def convert_ltx2_audio_vae_per_channel_statistics( + key: str, 
def convert_ltx2_audio_vae_per_channel_statistics(
    key: str, state_dict: Dict[str, Any]
) -> None:
    """Re-home a ``per_channel_statistics*`` entry under the ``decoder.`` prefix.

    Mutates ``state_dict`` in place; keys that do not start with
    ``per_channel_statistics`` are left untouched.
    """
    if not key.startswith("per_channel_statistics"):
        return
    state_dict[f"decoder.{key}"] = state_dict.pop(key)


def convert_ltx2_3_vocoder_upsamplers(key: str, state_dict: dict[str, Any]) -> None:
    """Rename ``.ups.`` to ``.upsamplers.`` for weight/bias entries, in place.

    This is done by a handler rather than the plain rename table because of a
    name clash (see the note on the LTX-2.3 vocoder rename dict).
    """
    # Only parameters whose name mentions a weight or bias are considered.
    is_weight_or_bias = ".weight" in key or ".bias" in key
    if is_weight_or_bias and ".ups." in key:
        state_dict[key.replace(".ups.", ".upsamplers.")] = state_dict.pop(key)


def split_transformer_and_connector_state_dict(
    state_dict: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Partition a combined checkpoint into transformer and connector entries.

    A key belongs to the connectors when it starts with one of the known
    connector prefixes (original or already-renamed spelling); every other key
    goes to the transformer. The input mapping is not modified.

    Returns:
        ``(transformer_state_dict, connector_state_dict)``.
    """
    connector_prefixes = (
        "video_embeddings_connector",
        "audio_embeddings_connector",
        "transformer_1d_blocks",
        "text_embedding_projection",
        "connectors.",
        "video_connector",
        "audio_connector",
        "text_proj_in",
    )

    transformer_part: Dict[str, Any] = {}
    connector_part: Dict[str, Any] = {}
    for name, tensor in state_dict.items():
        bucket = connector_part if name.startswith(connector_prefixes) else transformer_part
        bucket[name] = tensor

    return transformer_part, connector_part
"model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "vae_scale_factors": (8, 32, 32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 2, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 16, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": False, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1, + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "Lightricks/LTX-2", + "diffusers_config": { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "vae_scale_factors": (8, 32, 32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "gated_attn": False, + "cross_attn_mod": False, + "audio_in_channels": 128, + "audio_out_channels": 128, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_cross_attention_dim": 2048, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "audio_gated_attn": False, + 
"audio_cross_attn_mod": False, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 3840, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1000, + "rope_type": "split", + "use_prompt_embeddings": True, + "perturbed_attn": False, + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + "model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "vae_scale_factors": (8, 32, 32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "gated_attn": True, + "cross_attn_mod": True, + "audio_in_channels": 128, + "audio_out_channels": 128, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_cross_attention_dim": 2048, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "audio_gated_attn": True, + "audio_cross_attn_mod": True, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 3840, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1000, + "rope_type": "split", + "use_prompt_embeddings": False, + "perturbed_attn": True, + }, + } + rename_dict = 
def get_ltx2_connectors_config(
    version: str = "2.0",
) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Return the text-connector config and key-remapping tables for a version.

    Args:
        version: ``"test"`` (tiny dummy model), ``"2.0"`` or ``"2.3"``.

    Returns:
        ``(config, rename_dict, special_keys_remap)`` where ``config`` holds
        ``"model_id"`` and ``"diffusers_config"``, ``rename_dict`` maps original
        key fragments to their diffusers names, and ``special_keys_remap`` maps
        key fragments to in-place handler functions.

    Raises:
        ValueError: if ``version`` is not one of the supported values.
    """
    if version == "test":
        config = {
            "model_id": "diffusers-internal-dev/dummy-ltx2",
            "diffusers_config": {
                "caption_channels": 16,
                "text_proj_in_factor": 3,
                "video_connector_num_attention_heads": 4,
                "video_connector_attention_head_dim": 8,
                "video_connector_num_layers": 1,
                "video_connector_num_learnable_registers": None,
                "audio_connector_num_attention_heads": 4,
                "audio_connector_attention_head_dim": 8,
                "audio_connector_num_layers": 1,
                "audio_connector_num_learnable_registers": None,
                "connector_rope_base_seq_len": 32,
                "rope_theta": 10000.0,
                "rope_double_precision": False,
                "causal_temporal_positioning": False,
            },
        }
        # This branch previously left both remap tables unassigned, so the final
        # return raised UnboundLocalError. Use the 2.0 tables, matching what the
        # transformer config getter does for its own "test" version.
        rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
        special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
    elif version == "2.0":
        config = {
            "model_id": "Lightricks/LTX-2",
            "diffusers_config": {
                "caption_channels": 3840,
                "text_proj_in_factor": 49,
                "video_connector_num_attention_heads": 30,
                "video_connector_attention_head_dim": 128,
                "video_connector_num_layers": 2,
                "video_connector_num_learnable_registers": 128,
                "video_gated_attn": False,
                "audio_connector_num_attention_heads": 30,
                "audio_connector_attention_head_dim": 128,
                "audio_connector_num_layers": 2,
                "audio_connector_num_learnable_registers": 128,
                "audio_gated_attn": False,
                "connector_rope_base_seq_len": 4096,
                "rope_theta": 10000.0,
                "rope_double_precision": True,
                "causal_temporal_positioning": False,
                "rope_type": "split",
                "per_modality_projections": False,
                "proj_bias": False,
            },
        }
        rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
        special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
    elif version == "2.3":
        config = {
            "model_id": "Lightricks/LTX-2.3",
            "diffusers_config": {
                "caption_channels": 3840,
                "text_proj_in_factor": 49,
                "video_connector_num_attention_heads": 32,
                "video_connector_attention_head_dim": 128,
                "video_connector_num_layers": 8,
                "video_connector_num_learnable_registers": 128,
                "video_gated_attn": True,
                "audio_connector_num_attention_heads": 32,
                "audio_connector_attention_head_dim": 64,
                "audio_connector_num_layers": 8,
                "audio_connector_num_learnable_registers": 128,
                "audio_gated_attn": True,
                "connector_rope_base_seq_len": 4096,
                "rope_theta": 10000.0,
                "rope_double_precision": True,
                "causal_temporal_positioning": False,
                "rope_type": "split",
                "per_modality_projections": True,
                "video_hidden_dim": 4096,
                "audio_hidden_dim": 2048,
                "proj_bias": True,
            },
        }
        rename_dict = LTX_2_3_CONNECTORS_KEYS_RENAME_DICT
        # LTX-2.3 reuses the (empty) 2.0 special-key table for connectors.
        special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError(
            f"Unsupported LTX-2 connectors version {version!r}; "
            "expected one of 'test', '2.0', '2.3'."
        )

    return config, rename_dict, special_keys_remap
def convert_ltx2_connectors(
    original_state_dict: Dict[str, Any], version: str = "2.0"
) -> LTX2TextConnectors:
    """Build an ``LTX2TextConnectors`` module from an original-format checkpoint.

    Args:
        original_state_dict: combined checkpoint; only the connector entries
            (as selected by ``split_transformer_and_connector_state_dict``) are used.
        version: forwarded to ``get_ltx2_connectors_config``.

    Returns:
        The connectors module with the converted weights loaded (strict).

    Raises:
        ValueError: if the checkpoint contains no connector weights.
    """
    config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
    diffusers_config = config["diffusers_config"]

    connector_weights = split_transformer_and_connector_state_dict(
        original_state_dict
    )[1]
    if not connector_weights:
        raise ValueError("No connector weights found in the provided state dict.")

    # Create the module under init_empty_weights; the real tensors are attached
    # by load_state_dict(assign=True) below.
    with init_empty_weights():
        connectors = LTX2TextConnectors.from_config(diffusers_config)

    # Pass 1: apply every fragment rename, in table order, to each key.
    for old_name in list(connector_weights):
        renamed = old_name
        for fragment, replacement in rename_dict.items():
            renamed = renamed.replace(fragment, replacement)
        update_state_dict_inplace(connector_weights, old_name, renamed)

    # Pass 2: run the in-place handlers whose trigger fragment occurs in the key.
    for name in list(connector_weights):
        for fragment, handler in special_keys_remap.items():
            if fragment in name:
                handler(name, connector_weights)

    connectors.load_state_dict(connector_weights, strict=True, assign=True)
    return connectors
False, False, False), + "downsample_type": ( + "spatial", + "temporal", + "spatiotemporal", + "spatiotemporal", + ), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": timestep_conditioning, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "Lightricks/LTX-2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ( + "spatial", + "temporal", + "spatiotemporal", + "spatiotemporal", + ), + "upsample_type": ("spatiotemporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": timestep_conditioning, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + elif version == "2.3": + config = { + 
"model_id": "Lightricks/LTX-2.3", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 1024), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 512, 1024), + "layers_per_block": (4, 6, 4, 2, 2), + "decoder_layers_per_block": (4, 6, 4, 2, 2), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True, True), + "decoder_inject_noise": (False, False, False, False, False), + "downsample_type": ( + "spatial", + "temporal", + "spatiotemporal", + "spatiotemporal", + ), + "upsample_type": ( + "spatiotemporal", + "spatiotemporal", + "temporal", + "spatial", + ), + "upsample_residual": (False, False, False, False), + "upsample_factor": (2, 2, 1, 2), + "timestep_conditioning": timestep_conditioning, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "zeros", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_3_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_video_vae( + original_state_dict: Dict[str, Any], + version: str = "2.0", + timestep_conditioning: bool = False, +) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config( + version, timestep_conditioning + ) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Video.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = 
def get_ltx2_audio_vae_config(
    version: str = "2.0",
) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Return the audio-VAE config and key-remapping tables for a version.

    LTX-2.0 and LTX-2.3 share the same architecture settings and remap tables;
    only the ``model_id`` differs.

    Args:
        version: ``"2.0"`` or ``"2.3"``.

    Returns:
        ``(config, rename_dict, special_keys_remap)``. ``config`` is a freshly
        built dict on every call, so callers may mutate it.

    Raises:
        ValueError: if ``version`` is not supported (previously this surfaced
            as an UnboundLocalError at the return statement).
    """
    model_ids = {"2.0": "Lightricks/LTX-2", "2.3": "Lightricks/LTX-2.3"}
    if version not in model_ids:
        raise ValueError(
            f"Unsupported LTX-2 audio VAE version {version!r}; "
            "expected one of '2.0', '2.3'."
        )

    config = {
        "model_id": model_ids[version],
        "diffusers_config": {
            "base_channels": 128,
            "output_channels": 2,
            "ch_mult": (1, 2, 4),
            "num_res_blocks": 2,
            "attn_resolutions": None,
            "in_channels": 2,
            "resolution": 256,
            "latent_channels": 8,
            "norm_type": "pixel",
            "causality_axis": "height",
            "dropout": 0.0,
            "mid_block_add_attention": False,
            "sample_rate": 16000,
            "mel_hop_length": 160,
            "is_causal": True,
            "mel_bins": 64,
            "double_z": True,
        },
    }
    # Both versions use the LTX-2.0 audio VAE remap tables.
    rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
    special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
    return config, rename_dict, special_keys_remap
def get_ltx2_vocoder_config(
    version: str = "2.0",
) -> tuple[Dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Return the vocoder config and key-remapping tables for a version.

    Args:
        version: ``"2.0"`` or ``"2.3"``. The 2.3 config additionally carries the
            ``bwe_*`` and mel/STFT settings.

    Returns:
        ``(config, rename_dict, special_keys_remap)`` where ``config`` holds
        ``"model_id"`` and ``"diffusers_config"``.

    Raises:
        ValueError: if ``version`` is not supported (previously this surfaced
            as an UnboundLocalError at the return statement).
    """
    if version == "2.0":
        config = {
            "model_id": "Lightricks/LTX-2",
            "diffusers_config": {
                "in_channels": 128,
                "hidden_channels": 1024,
                "out_channels": 2,
                "upsample_kernel_sizes": [16, 15, 8, 4, 4],
                "upsample_factors": [6, 5, 2, 2, 2],
                "resnet_kernel_sizes": [3, 7, 11],
                "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "act_fn": "leaky_relu",
                "leaky_relu_negative_slope": 0.1,
                "antialias": False,
                "final_act_fn": "tanh",
                "final_bias": True,
                "output_sampling_rate": 24000,
            },
        }
        rename_dict = LTX_2_0_VOCODER_RENAME_DICT
        special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
    elif version == "2.3":
        config = {
            "model_id": "Lightricks/LTX-2.3",
            "diffusers_config": {
                "in_channels": 128,
                "hidden_channels": 1536,
                "out_channels": 2,
                "upsample_kernel_sizes": [11, 4, 4, 4, 4, 4],
                "upsample_factors": [5, 2, 2, 2, 2, 2],
                "resnet_kernel_sizes": [3, 7, 11],
                "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "act_fn": "snakebeta",
                "leaky_relu_negative_slope": 0.1,
                "antialias": True,
                "antialias_ratio": 2,
                "antialias_kernel_size": 12,
                "final_act_fn": None,
                "final_bias": False,
                "bwe_in_channels": 128,
                "bwe_hidden_channels": 512,
                "bwe_out_channels": 2,
                "bwe_upsample_kernel_sizes": [12, 11, 4, 4, 4],
                "bwe_upsample_factors": [6, 5, 2, 2, 2],
                "bwe_resnet_kernel_sizes": [3, 7, 11],
                "bwe_resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "bwe_act_fn": "snakebeta",
                "bwe_leaky_relu_negative_slope": 0.1,
                "bwe_antialias": True,
                "bwe_antialias_ratio": 2,
                "bwe_antialias_kernel_size": 12,
                "bwe_final_act_fn": None,
                "bwe_final_bias": False,
                "filter_length": 512,
                "hop_length": 80,
                "window_length": 512,
                "num_mel_channels": 64,
                "input_sampling_rate": 16000,
                "output_sampling_rate": 48000,
            },
        }
        rename_dict = LTX_2_3_VOCODER_RENAME_DICT
        special_keys_remap = LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP
    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError(
            f"Unsupported LTX-2 vocoder version {version!r}; "
            "expected one of '2.0', '2.3'."
        )
    return config, rename_dict, special_keys_remap
def get_model_state_dict_from_combined_ckpt(
    combined_ckpt: Dict[str, Any], prefix: str
) -> Dict[str, Any]:
    """Extract one sub-model's weights from a combined checkpoint.

    Args:
        combined_ckpt: mapping of fully-qualified parameter names to values.
        prefix: sub-model prefix; a trailing ``"."`` is added when missing.

    Returns:
        A new dict with the prefix stripped from every matching key. For the
        ``model.diffusion_model.`` prefix, top-level ``text_embedding_projection*``
        entries are added as well (unprefixed), without overwriting keys that
        were already extracted.
    """
    # Ensure that the key prefix ends with a dot (.)
    if not prefix.endswith("."):
        prefix = prefix + "."

    model_state_dict = {
        name.removeprefix(prefix): param
        for name, param in combined_ckpt.items()
        if name.startswith(prefix)
    }

    if prefix == "model.diffusion_model.":
        # Some checkpoints store the text connector projection outside the
        # diffusion model prefix. A tuple passed to startswith() avoids the
        # inner loop that used to rebind the ``prefix`` parameter.
        connector_prefixes = ("text_embedding_projection",)
        for name, param in combined_ckpt.items():
            if name.startswith(connector_prefixes) and name not in model_state_dict:
                model_state_dict[name] = param

    return model_state_dict


def dequantize_state_dict(state_dict: Dict[str, Any]):
    """Return a copy of ``state_dict`` with scaled weights expanded to bfloat16.

    Any ``X.weight`` that has a sibling ``X.weight_scale`` is replaced by
    ``weight.to(bfloat16) * weight_scale.to(bfloat16)``. The quantization
    side-car entries (``.weight_scale``, ``.weight_scale_2``,
    ``.pre_quant_scale``, ``.input_scale``) are dropped; every other entry is
    passed through unchanged. The input dict is not modified.
    """
    scale_suffixes = (
        ".weight_scale",
        ".weight_scale_2",
        ".pre_quant_scale",
        ".input_scale",
    )
    state_out = {}
    for name, tensor in state_dict.items():
        if name.endswith(scale_suffixes):
            continue

        if name.endswith(".weight"):
            scale = state_dict.get(name[: -len(".weight")] + ".weight_scale")
            if scale is not None:
                # Comfy quant = absmax per-tensor weight quant, nothing fancy
                state_out[name] = tensor.to(torch.bfloat16) * scale.to(torch.bfloat16)
                continue

        state_out[name] = tensor
    return state_out


def convert_comfy_gemma3_to_transformers(sd: dict):
    """Rename a Comfy-style Gemma 3 checkpoint to the transformers key layout.

    The state dict is dequantized first (see ``dequantize_state_dict``). Keys are
    then mapped as follows:

    * ``vision_model.*``            -> ``model.vision_tower.vision_model.*``
    * ``multi_modal_projector.*``   -> ``model.multi_modal_projector.*``
    * ``model.embed_tokens.weight`` -> ``model.language_model.embed_tokens.weight``
    * ``model.layers.*``            -> ``model.language_model.layers.*``
    * ``model.norm.*``              -> ``model.language_model.norm.*``

    A leading ``module.`` (DDP wrapper) is removed *before* this mapping so that
    wrapped keys are remapped too; previously it was removed afterwards, which
    left e.g. ``module.model.layers.*`` as an unmapped ``model.layers.*``.
    The ``spiece_model`` entry is skipped, and ``lm_head.weight`` is tied to the
    token embeddings when the checkpoint does not provide one.
    """
    out = {}

    sd = dequantize_state_dict(sd)

    for k, v in sd.items():
        # (optional) common DDP prefix -- strip it first so the rules below apply.
        if k.startswith("module."):
            k = k[len("module.") :]

        nk = k

        # Vision tower weights: checkpoint has "vision_model.*"
        # model expects "model.vision_tower.vision_model.*"
        if k.startswith("vision_model."):
            nk = "model.vision_tower." + k

        # MM projector: checkpoint has "multi_modal_projector.*"
        # model expects "model.multi_modal_projector.*"
        elif k.startswith("multi_modal_projector."):
            nk = "model." + k

        # Language model: checkpoint has "model.layers.*", "model.embed_tokens.*", "model.norm.*"
        # model expects "model.language_model.layers.*", etc.
        elif k == "model.embed_tokens.weight":
            nk = "model.language_model.embed_tokens.weight"
        elif k.startswith("model.layers."):
            nk = "model.language_model.layers." + k[len("model.layers.") :]
        elif k.startswith("model.norm."):
            nk = "model.language_model.norm." + k[len("model.norm.") :]

        # skip spiece_model
        if nk == "spiece_model":
            continue

        out[nk] = v

    # If lm_head is missing but embeddings exist, many Gemma-family models tie these weights.
    # Add it so strict loading won't complain (or just load strict=False and call tie_weights()).
    if (
        "lm_head.weight" not in out
        and "model.language_model.embed_tokens.weight" in out
    ):
        out["lm_head.weight"] = out["model.language_model.embed_tokens.weight"]

    return out
+ rest = rest[len(prefix) :] + + nk = rest + + # Same simple 1:1 remaps as the transformer + for replace_key, rename_key in rename_dict.items(): + nk = nk.replace(replace_key, rename_key) + + # Same special-case remap as the transformer (applies to LoRA keys too) + if nk.startswith("adaln_single."): + nk = nk.replace("adaln_single.", "time_embed.", 1) + elif nk.startswith("audio_adaln_single."): + nk = nk.replace("audio_adaln_single.", "audio_time_embed.", 1) + + out[prefix + nk] = v + + return out + + +def convert_lora_diffusers_to_original( + lora_state_dict: Dict[str, Any], + version: str = "2.0", +) -> Dict[str, Any]: + out: Dict[str, Any] = {} + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + if version == "2.3": + rename_dict = LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT + + inv_rename = {v: k for k, v in rename_dict.items()} + inv_items = sorted(inv_rename.items(), key=lambda kv: len(kv[0]), reverse=True) + + for k, v in lora_state_dict.items(): + # Keep the "diffusion_model." prefix as-is, but invert remaps on the rest + prefix = "" + rest = k + if rest.startswith("diffusion_model."): + prefix = "diffusion_model." 
+ rest = rest[len(prefix) :] + + nk = rest + + # Inverse of the adaln_single special-case + if nk.startswith("time_embed."): + nk = nk.replace("time_embed.", "adaln_single.", 1) + elif nk.startswith("audio_time_embed."): + nk = nk.replace("audio_time_embed.", "audio_adaln_single.", 1) + + # Inverse 1:1 remaps + for diffusers_key, original_key in inv_items: + nk = nk.replace(diffusers_key, original_key) + + out[prefix + nk] = v + + return out diff --git a/ai-toolkit/extensions_built_in/diffusion_models/ltx2/ltx2.py b/ai-toolkit/extensions_built_in/diffusion_models/ltx2/ltx2.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1465cb63ee8fda44d19c30604a8573eaf5c42f --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/ltx2/ltx2.py @@ -0,0 +1,1149 @@ +from functools import partial +import os +from typing import List, Optional + +import torch +import torchaudio +from transformers import Gemma3Config +import yaml +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from accelerate import init_empty_weights +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.memory_management import MemoryManager +from safetensors.torch import load_file +from PIL import Image +import huggingface_hub + +try: + from diffusers import LTX2Pipeline, LTX2ImageToVideoPipeline + from diffusers.models.autoencoders import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + ) + from diffusers.models.transformers import LTX2VideoTransformer3DModel + from diffusers.pipelines.ltx2.export_utils import encode_video + from transformers 
import ( + Gemma3ForConditionalGeneration, + GemmaTokenizerFast, + ) + from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder, LTX2VocoderWithBWE + from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors + from .convert_ltx2_to_diffusers import ( + get_model_state_dict_from_combined_ckpt, + convert_ltx2_transformer, + convert_ltx2_video_vae, + convert_ltx2_audio_vae, + convert_ltx2_vocoder, + convert_ltx2_connectors, + dequantize_state_dict, + convert_comfy_gemma3_to_transformers, + convert_lora_original_to_diffusers, + convert_lora_diffusers_to_original, + ) +except ImportError as e: + print("Diffusers import error:", e) + raise ImportError( + "Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt" + ) + + +scheduler_config = { + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "stochastic_sampling": False, + "time_shift_type": "exponential", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} + +dit_prefix = "model.diffusion_model." +vae_prefix = "vae." +audio_vae_prefix = "audio_vae." +vocoder_prefix = "vocoder." +base_te_path = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized" + +HF_TOKEN = os.getenv("HF_TOKEN", None) + + +def new_save_image_function( + self: GenerateImageConfig, + image, # will contain a dict that can be dumped ditectly into encode_video, just add output_path to it. 
+ count: int = 0, + max_count: int = 0, + **kwargs, +): + # this replaces gen image config save image function so we can save the video with sound from ltx2 + image["output_path"] = self.get_image_path(count, max_count) + # make sample directory if it does not exist + os.makedirs(os.path.dirname(image["output_path"]), exist_ok=True) + encode_video(**image) + flush() + + +def blank_log_image_function(self, *args, **kwargs): + # todo handle wandb logging of videos with audio + return + + +class ComboVae(torch.nn.Module): + """Combines video and audio VAEs for joint encoding and decoding.""" + + def __init__( + self, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + ) -> None: + super().__init__() + self.vae = vae + self.audio_vae = audio_vae + + @property + def device(self): + return self.vae.device + + @property + def dtype(self): + return self.vae.dtype + + @property + def config(self): + return self.vae.config + + def encode( + self, + *args, + **kwargs, + ): + return self.vae.encode(*args, **kwargs) + + def decode( + self, + *args, + **kwargs, + ): + return self.vae.decode(*args, **kwargs) + + +class AudioProcessor(torch.nn.Module): + """Converts audio waveforms to log-mel spectrograms with optional resampling.""" + + def __init__( + self, + sample_rate: int, + mel_bins: int, + mel_hop_length: int, + n_fft: int, + ) -> None: + super().__init__() + self.sample_rate = sample_rate + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + win_length=n_fft, + hop_length=mel_hop_length, + f_min=0.0, + f_max=sample_rate / 2.0, + n_mels=mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ) + + def resample_waveform( + self, + waveform: torch.Tensor, + source_rate: int, + target_rate: int, + ) -> torch.Tensor: + """Resample waveform to target sample rate if needed.""" + if source_rate == target_rate: + return waveform + 
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate) + return resampled.to(device=waveform.device, dtype=waveform.dtype) + + def waveform_to_mel( + self, + waveform: torch.Tensor, + waveform_sample_rate: int, + ) -> torch.Tensor: + """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels].""" + waveform = self.resample_waveform( + waveform, waveform_sample_rate, self.sample_rate + ) + + mel = self.mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + + mel = mel.to(device=waveform.device, dtype=waveform.dtype) + return mel.permute(0, 1, 3, 2).contiguous() + + +class LTX2Model(BaseModel): + arch = "ltx2" + ltx_version = "2.0" + ltx_te_path = None + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["LTX2VideoTransformer3DModel"] + # defines if the model supports model paths. 
Only some will + self.supports_model_paths = True + # use the new format on this new model by default + self.use_old_lokr_format = False + self.audio_processor = None + + # gemma needs left side padding + self.te_padding_side = "left" + + # invalidate older caches + self.latent_space_version = f"{self.arch}_v2" + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return 32 + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading LTX2 model") + model_path = self.model_config.name_or_path + base_model_path = self.model_config.extras_name_or_path + + combined_state_dict = None + + self.print_and_status_update("Loading transformer") + + if not os.path.exists(model_path) and model_path.endswith(".safetensors"): + # download the model from the Hugging Face Hub if it is not a local path + splits = model_path.split("/") + if len(splits) != 3: + raise ValueError( + f"Invalid model path: {model_path}. Must be in the format 'repo_id/repo/filename.safetensors' to download from the Hugging Face Hub." 
+ ) + # download the model from the hub + model_path = huggingface_hub.hf_hub_download( + repo_id="/".join(splits[:2]), + filename=splits[2], + token=HF_TOKEN, + ) + + # if we have a safetensors file it is a mono checkpoint + if os.path.exists(model_path) and model_path.endswith(".safetensors"): + combined_state_dict = load_file(model_path) + combined_state_dict = dequantize_state_dict(combined_state_dict) + + if combined_state_dict is not None: + original_dit_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, dit_prefix + ) + transformer = convert_ltx2_transformer( + original_dit_ckpt, version=self.ltx_version + ) + transformer = transformer.to(dtype) + else: + transformer_path = model_path + transformer_subfolder = "transformer" + if os.path.exists(transformer_path): + transformer_subfolder = None + transformer_path = os.path.join(transformer_path, "transformer") + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, "text_encoder") + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = LTX2VideoTransformer3DModel.from_pretrained( + transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype + ) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + ignore_modules = [] + for block in transformer.transformer_blocks: + ignore_modules.append(block.scale_shift_table) + ignore_modules.append(block.audio_scale_shift_table) + ignore_modules.append(block.video_a2v_cross_attn_scale_shift_table) + ignore_modules.append(block.audio_a2v_cross_attn_scale_shift_table) + ignore_modules.append(transformer.scale_shift_table) + ignore_modules.append(transformer.audio_scale_shift_table) + MemoryManager.attach( + 
transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=ignore_modules, + ) + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + + flush() + + self.print_and_status_update("Loading text encoder") + if ( + self.model_config.te_name_or_path is not None + and self.model_config.te_name_or_path.endswith(".safetensors") + ): + # load from comfyui gemma3 checkpoint + tokenizer = GemmaTokenizerFast.from_pretrained(base_te_path) + + with init_empty_weights(): + text_encoder = Gemma3ForConditionalGeneration( + Gemma3Config( + **{ + "boi_token_index": 255999, + "bos_token_id": 2, + "eoi_token_index": 256000, + "eos_token_id": 106, + "image_token_index": 262144, + "initializer_range": 0.02, + "mm_tokens_per_image": 256, + "model_type": "gemma3", + "pad_token_id": 0, + "text_config": { + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": None, + "cache_implementation": "hybrid", + "final_logit_softcapping": None, + "head_dim": 256, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 3840, + "initializer_range": 0.02, + "intermediate_size": 15360, + "max_position_embeddings": 131072, + "model_type": "gemma3_text", + "num_attention_heads": 16, + "num_hidden_layers": 48, + "num_key_value_heads": 8, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_local_base_freq": 10000, + "rope_scaling": {"factor": 8.0, "rope_type": "linear"}, + "rope_theta": 1000000, + "sliding_window": 1024, + "sliding_window_pattern": 6, + "torch_dtype": "bfloat16", + "use_cache": True, + "vocab_size": 262208, + }, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.3", + "unsloth_fixed": True, + "vision_config": { + "attention_dropout": 0.0, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 896, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": 
"siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "patch_size": 14, + "torch_dtype": "bfloat16", + "vision_use_head": False, + }, + } + ) + ) + te_state_dict = load_file(self.model_config.te_name_or_path) + te_state_dict = convert_comfy_gemma3_to_transformers(te_state_dict) + for key in te_state_dict: + te_state_dict[key] = te_state_dict[key].to(dtype) + + text_encoder.load_state_dict(te_state_dict, assign=True, strict=True) + del te_state_dict + flush() + elif self.model_config.te_name_or_path is not None: + # a repo or folder + tokenizer = GemmaTokenizerFast.from_pretrained( + self.model_config.te_name_or_path + ) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + self.model_config.te_name_or_path, dtype=dtype + ) + elif self.ltx_te_path is not None: + # pull from model specific te + tokenizer = GemmaTokenizerFast.from_pretrained(self.ltx_te_path) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + self.ltx_te_path, dtype=dtype + ) + else: + # using combo hf repo + tokenizer = GemmaTokenizerFast.from_pretrained( + self.model_config.name_or_path, subfolder="tokenizer" + ) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + self.model_config.name_or_path, subfolder="text_encoder", dtype=dtype + ) + + # remove the vision tower + text_encoder.model.vision_tower = None + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Text Encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ignore_modules=[ + text_encoder.model.language_model.base_model.embed_tokens + ], + ) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() 
+ + self.print_and_status_update("Loading VAEs and other components") + if combined_state_dict is not None: + original_vae_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, vae_prefix + ) + vae = convert_ltx2_video_vae( + original_vae_ckpt, version=self.ltx_version + ).to(dtype) + del original_vae_ckpt + original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, audio_vae_prefix + ) + audio_vae = convert_ltx2_audio_vae( + original_audio_vae_ckpt, version=self.ltx_version + ).to(dtype) + del original_audio_vae_ckpt + original_connectors_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, dit_prefix + ) + connectors = convert_ltx2_connectors( + original_connectors_ckpt, version=self.ltx_version + ).to(dtype) + del original_connectors_ckpt + original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt( + combined_state_dict, vocoder_prefix + ) + vocoder = convert_ltx2_vocoder( + original_vocoder_ckpt, version=self.ltx_version + ).to(dtype) + del original_vocoder_ckpt + del combined_state_dict + flush() + else: + vae = AutoencoderKLLTX2Video.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype + ) + audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + base_model_path, subfolder="audio_vae", torch_dtype=dtype + ) + + connectors = LTX2TextConnectors.from_pretrained( + base_model_path, subfolder="connectors", torch_dtype=dtype + ) + + vocoder_cls = LTX2Vocoder + if self.ltx_version == "2.3": + vocoder_cls = LTX2VocoderWithBWE + + vocoder = vocoder_cls.from_pretrained( + base_model_path, subfolder="vocoder", torch_dtype=dtype + ) + + self.noise_scheduler = LTX2Model.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + pipe: LTX2Pipeline = LTX2Pipeline( + scheduler=self.noise_scheduler, + vae=vae, + audio_vae=audio_vae, + text_encoder=None, + tokenizer=tokenizer, + connectors=connectors, + transformer=None, + vocoder=vocoder, + ) + # for quantization, it works 
best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] + + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + flush() + + # save it to the model class + self.vae = ComboVae(pipe.vae, pipe.audio_vae) + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + + self.audio_processor = AudioProcessor( + sample_rate=pipe.audio_sampling_rate, + mel_bins=audio_vae.config.mel_bins, + mel_hop_length=pipe.audio_hop_length, + n_fft=1024, # todo get this from vae if we can, I couldnt find it. 
+ ).to(self.device_torch, dtype=torch.float32) + + self.print_and_status_update("Model Loaded") + + @torch.no_grad() + def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + if self.pipeline.vae.device == torch.device("cpu"): + self.pipeline.vae.to(device) + self.pipeline.vae.eval() + self.pipeline.vae.requires_grad_(False) + + if self.model_config.low_vram: + self.pipeline.vae.tile_sample_min_num_frames = 64 + self.pipeline.vae.tile_sample_stride_num_frames = 16 + # they check the wrong flat on encode currently so set both to future proof + self.pipeline.vae.use_framewise_decoding = True + self.pipeline.vae.use_framewise_encoding = True + + image_list = [image.to(device, dtype=dtype) for image in image_list] + + # Normalize shapes + norm_images = [] + for image in image_list: + if image.ndim == 3: + # (C, H, W) -> (C, 1, H, W) + norm_images.append(image.unsqueeze(1)) + elif image.ndim == 4: + # (T, C, H, W) -> (C, T, H, W) + norm_images.append(image.permute(1, 0, 2, 3)) + else: + raise ValueError(f"Invalid image shape: {image.shape}") + + # Stack to (B, C, T, H, W) + images = torch.stack(norm_images) + + latents = self.pipeline.vae.encode(images).latent_dist.sample() + + # Normalize latents across the channel dimension [B, C, F, H, W] + scaling_factor = 1.0 + latents_mean = self.pipeline.vae.latents_mean.view(1, -1, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents_std = self.pipeline.vae.latents_std.view(1, -1, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = (latents - latents_mean) * scaling_factor / latents_std + + if self.model_config.low_vram: + self.pipeline.vae.use_framewise_decoding = False + self.pipeline.vae.use_framewise_encoding = False + + return latents.to(device, dtype=dtype) + + def get_generation_pipeline(self): + scheduler = LTX2Model.get_train_scheduler() + + pipeline: LTX2Pipeline = 
LTX2Pipeline( + scheduler=scheduler, + vae=unwrap_model(self.pipeline.vae), + audio_vae=unwrap_model(self.pipeline.audio_vae), + text_encoder=None, + tokenizer=unwrap_model(self.pipeline.tokenizer), + connectors=unwrap_model(self.pipeline.connectors), + transformer=None, + vocoder=unwrap_model(self.pipeline.vocoder), + ) + pipeline.transformer = unwrap_model(self.model) + pipeline.text_encoder = unwrap_model(self.text_encoder[0]) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: LTX2Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + # handle control image + if gen_config.ctrl_img is not None: + # switch to image to video pipeline + pipeline = LTX2ImageToVideoPipeline( + scheduler=pipeline.scheduler, + vae=pipeline.vae, + audio_vae=pipeline.audio_vae, + text_encoder=pipeline.text_encoder, + tokenizer=pipeline.tokenizer, + connectors=pipeline.connectors, + transformer=pipeline.transformer, + vocoder=pipeline.vocoder, + ) + + is_video = gen_config.num_frames > 1 + # override the generate single image to handle video + audio generation + if is_video: + gen_config._orig_save_image_function = gen_config.save_image + gen_config.save_image = partial(new_save_image_function, gen_config) + gen_config.log_image = partial(blank_log_image_function, gen_config) + # set output extension to mp4 + gen_config.output_ext = "mp4" + + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) + pipeline = pipeline.to(self.device_torch) + + # make sure dimensions are valid + bd = self.get_bucket_divisibility() + gen_config.height = (gen_config.height // bd) * bd + gen_config.width = (gen_config.width // bd) * bd + + # handle control image + if gen_config.ctrl_img is not None: + 
control_img = Image.open(gen_config.ctrl_img).convert("RGB") + # resize the control image + control_img = control_img.resize( + (gen_config.width, gen_config.height), Image.LANCZOS + ) + # add the control image to the extra dict + extra["image"] = control_img + + # frames must be divisible by 8 then + 1. so 1, 9, 17, 25, etc. + if gen_config.num_frames != 1: + if (gen_config.num_frames - 1) % 8 != 0: + gen_config.num_frames = ((gen_config.num_frames - 1) // 8) * 8 + 1 + + if self.low_vram: + # set vae to tile decode + # pipeline.vae.enable_tiling( + # tile_sample_min_height=256, + # tile_sample_min_width=256, + # tile_sample_min_num_frames=8, + # tile_sample_stride_height=224, + # tile_sample_stride_width=224, + # tile_sample_stride_num_frames=4, + # ) + self.pipeline.vae.tile_sample_min_num_frames = 16 + self.pipeline.vae.tile_sample_stride_num_frames = 8 + self.pipeline.vae.use_framewise_decoding = True + + # We only encode and store the minimum prompt tokens, but need them padded to 1024 for LTX2 + conditional_embeds = self.pad_embeds(conditional_embeds) + unconditional_embeds = self.pad_embeds(unconditional_embeds) + + if self.ltx_version == "2.3": + extra["stg_scale"] = 1.0 + extra["modality_scale"] = 3.0 + extra["guidance_rescale"] = 0.7 + extra["audio_guidance_scale"] = 7.0 + extra["audio_stg_scale"] = 1.0 + extra["audio_modality_scale"] = 3.0 + extra["audio_guidance_rescale"] = 0.7 + extra["spatio_temporal_guidance_blocks"] = [28] + extra["use_cross_timestep"] = ( + True # they dont set this in some examples in diffusers, but I believe it should always be true for 2.3 + ) + + video, audio = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + prompt_attention_mask=conditional_embeds.attention_mask.to( + self.device_torch + ), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + 
negative_prompt_attention_mask=unconditional_embeds.attention_mask.to( + self.device_torch + ), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + num_frames=gen_config.num_frames, + generator=generator, + return_dict=False, + output_type="np" if is_video else "pil", + **extra, + ) + if self.low_vram: + # Restore no tiling + # pipeline.vae.use_tiling = False + self.pipeline.vae.use_framewise_decoding = False + + if is_video: + # redurn as a dict, we will handle it with an override function + video = (video * 255).round().astype("uint8") + video = torch.from_numpy(video) + return { + "video": video[0], + "fps": gen_config.fps, + "audio": audio[0].float().cpu(), + "audio_sample_rate": pipeline.vocoder.config.output_sampling_rate, # should be 24000 + "output_path": None, + } + else: + # shape = [1, frames, channels, height, width] + # make sure this is right + video = video[0] # list of pil images + audio = audio[0] # tensor + if gen_config.num_frames > 1: + return video # return the frames. 
+ else: + # get just the first image + img = video[0] + return img + + def encode_audio(self, audio_data_list): + # audio_date_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)} + if self.pipeline.audio_vae.device == torch.device("cpu"): + self.pipeline.audio_vae.to(self.device_torch) + + output_tensor = None + audio_num_frames = None + + # do them seperatly for now + for audio_data in audio_data_list: + waveform = audio_data["waveform"].to( + device=self.device_torch, dtype=torch.float32 + ) + sample_rate = audio_data["sample_rate"] + + # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples] + if waveform.dim() == 2: + waveform = waveform.unsqueeze(0) + + if waveform.shape[1] == 1: + # make sure it is stereo + waveform = waveform.repeat(1, 2, 1) + + # Convert waveform to mel spectrogram using AudioProcessor + mel_spectrogram = self.audio_processor.waveform_to_mel( + waveform, waveform_sample_rate=sample_rate + ) + mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype) + + # Encode mel spectrogram to latents + latents = self.pipeline.audio_vae.encode( + mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype) + ).latent_dist.sample() + + if audio_num_frames is None: + audio_num_frames = latents.shape[2] # (latents is [B, C, T, F]) + + packed_latents = self.pipeline._pack_audio_latents( + latents, + # patch_size=self.pipeline.transformer.config.audio_patch_size, + # patch_size_t=self.pipeline.transformer.config.audio_patch_size_t, + ) # [B, L, C * M] + if output_tensor is None: + output_tensor = packed_latents + else: + output_tensor = torch.cat([output_tensor, packed_latents], dim=0) + + # normalize latents, opposite of (latents * latents_std) + latents_mean + latents_mean = self.pipeline.audio_vae.latents_mean + latents_std = self.pipeline.audio_vae.latents_std + output_tensor = (output_tensor - latents_mean) / latents_std + return output_tensor + + def pad_embeds(self, embeds: PromptEmbeds): + # 
ltx-2 connector requires 1024 tokens for good results. Any smaller and it degrades. + target_length = 1024 + current_length = embeds.text_embeds.shape[1] + if current_length < target_length: + pad_length = target_length - current_length + pad_tensor = torch.zeros( + (embeds.text_embeds.shape[0], pad_length, embeds.text_embeds.shape[2]), + device=embeds.text_embeds.device, + dtype=embeds.text_embeds.dtype, + ) + embeds.text_embeds = torch.cat([pad_tensor, embeds.text_embeds], dim=1) + if embeds.attention_mask is not None: + pad_mask = torch.zeros( + (embeds.attention_mask.shape[0], pad_length), + device=embeds.attention_mask.device, + dtype=embeds.attention_mask.dtype, + ) + embeds.attention_mask = torch.cat( + [pad_mask, embeds.attention_mask], dim=1 + ) + return embeds + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: "DataLoaderBatchDTO" = None, + **kwargs, + ): + with torch.no_grad(): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + # We only encode and store the minimum prompt tokens, but need them padded to 1024 for LTX2 + text_embeddings = self.pad_embeds(text_embeddings) + + batch_size, C, latent_num_frames, latent_height, latent_width = ( + latent_model_input.shape + ) + + video_timestep = timestep.clone() + + # i2v from first frame + if batch.dataset_config.do_i2v and batch.num_frames > 1: + # check to see if we had it cached + if batch.first_frame_latents is not None: + init_latents = batch.first_frame_latents.to( + self.device_torch, dtype=self.torch_dtype + ) + else: + # extract the first frame and encode it + # videos come in (bs, num_frames, channels, height, width) + # images come in (bs, channels, height, width) + frames = batch.tensor + if len(frames.shape) == 4: + first_frames = frames + elif len(frames.shape) == 5: + first_frames = frames[:, 0] + else: + raise ValueError(f"Unknown frame shape 
{frames.shape}") + # first frame doesnt have time dim, add it back + init_latents = self.encode_images( + first_frames, device=self.device_torch, dtype=self.torch_dtype + ) + + # expand the latents to match video frames + init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1) + mask_shape = ( + batch_size, + 1, + latent_num_frames, + latent_height, + latent_width, + ) + # First condition is image latents and those should be kept clean. + conditioning_mask = torch.zeros( + mask_shape, device=self.device_torch, dtype=self.torch_dtype + ) + conditioning_mask[:, :, 0] = 1.0 + + # use conditioning mask to replace latents + latent_model_input = ( + init_latents * conditioning_mask + + latent_model_input * (1 - conditioning_mask) + ) + + packed_conditioning_mask = self.pipeline._pack_latents( + conditioning_mask, + patch_size=self.pipeline.transformer_spatial_patch_size, + patch_size_t=self.pipeline.transformer_temporal_patch_size, + ) + + # set video timestep + video_timestep = timestep.unsqueeze(-1) * (1 - packed_conditioning_mask) + + frame_rate = batch.dataset_config.fps + # check frame dimension + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. 
+ packed_latents = self.pipeline._pack_latents( + latent_model_input, + patch_size=self.pipeline.transformer_spatial_patch_size, + patch_size_t=self.pipeline.transformer_temporal_patch_size, + ) + + if batch.audio_latents is not None or batch.audio_tensor is not None: + if batch.audio_latents is not None: + # we have audio latents cached + raw_audio_latents = batch.audio_latents.to( + self.device_torch, dtype=self.torch_dtype + ) + else: + # we have audio waveforms to encode + # use audio from the batch if available + raw_audio_latents = self.encode_audio(batch.audio_data) + + audio_num_frames = raw_audio_latents.shape[1] + # add the audio targets to the batch for loss calculation later + audio_noise = torch.randn_like(raw_audio_latents) + batch.audio_target = (audio_noise - raw_audio_latents).detach() + audio_latents = self.add_noise( + raw_audio_latents, + audio_noise, + timestep, + ).to(self.device_torch, dtype=self.torch_dtype) + else: + # no audio + num_mel_bins = self.pipeline.audio_vae.config.mel_bins + # latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.pipeline.audio_vae.config.latent_channels + ) + duration_s = batch.num_frames / frame_rate + audio_latents_per_second = ( + self.pipeline.audio_sampling_rate + / self.pipeline.audio_hop_length + / float(self.pipeline.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + audio_latents = self.pipeline.prepare_audio_latents( + batch_size, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=0.0, + dtype=torch.float32, + device=self.transformer.device, + generator=None, + latents=None, + ) + + if self.pipeline.connectors.device != self.transformer.device: + self.pipeline.connectors.to(self.transformer.device) + + # Padding side for default Gemma3-12B text encoder + tokenizer_padding_side = "left" + if getattr(self, 
"tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") + ( + connector_prompt_embeds, + connector_audio_prompt_embeds, + connector_attention_mask, + ) = self.pipeline.connectors( + text_embeddings.text_embeds, + text_embeddings.attention_mask.to(self.transformer.dtype), + padding_side=tokenizer_padding_side, + ) + + # compute video and audio positional ids + video_coords = self.transformer.rope.prepare_video_coords( + packed_latents.shape[0], + latent_num_frames, + latent_height, + latent_width, + packed_latents.device, + fps=frame_rate, + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + + # use_cross_timestep - Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when + # calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior; + # `False` is the legacy LTX-2.0 behavior. 
+ use_cross_timestep = self.ltx_version == "2.3" + + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=packed_latents, + audio_hidden_states=audio_latents.to(self.transformer.dtype), + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + sigma=timestep, # Used by LTX-2.3 + audio_timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=None, + return_dict=False, + ) + + # add audio latent to batch if we had audio + if batch.audio_target is not None: + batch.audio_pred = noise_pred_audio + + unpacked_output = self.pipeline._unpack_latents( + latents=noise_pred_video, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + patch_size=self.pipeline.transformer_spatial_patch_size, + patch_size_t=self.pipeline.transformer_temporal_patch_size, + ) + + return unpacked_output + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + device = self.device_torch + scale_factor = 8 + batch_size = len(prompt) + # Gemma expects left padding for chat-style prompts + self.tokenizer[0].padding_side = "left" + if self.tokenizer[0].pad_token is None: + self.tokenizer[0].pad_token = self.tokenizer[0].eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer[0]( + prompt, + # padding="max_length", + padding="longest", + max_length=1024, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + 
text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder[0]( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to( + dtype=self.torch_dtype + ) # Pack to 3D + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, 1, 1) + prompt_embeds = prompt_embeds.view(batch_size * 1, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, 1) + + pe = PromptEmbeds([prompt_embeds, None]) + pe.attention_mask = prompt_attention_mask + return pe + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + transformer: LTX2VideoTransformer3DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, "transformer"), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, "aitk_meta.yaml") + with open(meta_path, "w") as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get("noise") + batch = kwargs.get("batch") + return (noise - batch.latents).detach() + + def get_base_model_version(self): + return "ltx2" + + def get_transformer_block_names(self) -> Optional[List[str]]: + return ["transformer_blocks"] + + def convert_lora_weights_before_save(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", 
"diffusion_model.") + new_sd[new_key] = value + new_sd = convert_lora_diffusers_to_original(new_sd, version=self.ltx_version) + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + state_dict = convert_lora_original_to_diffusers( + state_dict, version=self.ltx_version + ) + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + + +class LTX23Model(LTX2Model): + arch = "ltx2.3" + ltx_version = "2.3" + ltx_te_path = base_te_path diff --git a/ai-toolkit/extensions_built_in/diffusion_models/nucleus_image/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/nucleus_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5798f5d810b539d56a561c74548807f9f4131be --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/nucleus_image/__init__.py @@ -0,0 +1 @@ +from .nucleus_image_model import NucleusImageModel \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/nucleus_image/nucleus_image_model.py b/ai-toolkit/extensions_built_in/diffusion_models/nucleus_image/nucleus_image_model.py new file mode 100644 index 0000000000000000000000000000000000000000..572fda831439f64e6950c671b8eae681f3aac90f --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/nucleus_image/nucleus_image_model.py @@ -0,0 +1,420 @@ +import itertools +import os +from typing import List, Optional + +import torch +import yaml +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.advanced_prompt_embeds import AdvancedPromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype, quantize_model 
+from toolkit.memory_management import MemoryManager + +from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor +import torch.nn.functional as F + +try: + from diffusers import NucleusMoEImagePipeline, NucleusMoEImageTransformer2DModel, AutoencoderKLQwenImage + from diffusers.models.transformers.transformer_nucleusmoe_image import SwiGLUExperts +except ImportError: + raise ImportError( + "Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt" + ) + + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": None, + "stochastic_sampling": False, + "time_shift_type": "exponential", + "use_beta_sigmas": False, + "use_dynamic_shifting": False, + "use_exponential_sigmas": False, + "use_karras_sigmas": False +} + + +class NucleusImageModel(BaseModel): + arch = "nucleus_image" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["NucleusMoEImageTransformer2DModel"] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return 16 * 2 # 16 for the VAE, 2 for patch size + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading Nucleus model") + model_path = self.model_config.name_or_path + base_model_path = self.model_config.extras_name_or_path + + self.print_and_status_update("Loading transformer") + + transformer_path = model_path + transformer_subfolder = 
"transformer" + if os.path.exists(transformer_path): + transformer_subfolder = None + transformer_path = os.path.join(transformer_path, "transformer") + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, "text_encoder") + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = NucleusMoEImageTransformer2DModel.from_pretrained( + transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype + ) + + # handle versions of pytorch that don't have grouped mm, by disabling it in the SwiGLUExperts + if not hasattr(torch.nn.functional, "grouped_mm"): + for m in transformer.modules(): + if isinstance(m, SwiGLUExperts): + m.use_grouped_mm = False + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=[ + ], + ) + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + + flush() + + self.print_and_status_update("Text Encoder") + tokenizer = Qwen3VLProcessor.from_pretrained( + base_model_path, subfolder="processor", torch_dtype=dtype + ) + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype + ) + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + 
self.print_and_status_update("Quantizing Text Encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + + self.print_and_status_update("Loading VAE") + vae = AutoencoderKLQwenImage.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype + ).to(self.device_torch, dtype=dtype) + + self.noise_scheduler = NucleusImageModel.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + kwargs = {} + + pipe: NucleusMoEImagePipeline = NucleusMoEImagePipeline( + scheduler=self.noise_scheduler, + text_encoder=None, + processor=tokenizer, + vae=vae, + transformer=None, + **kwargs, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.processor] + + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = NucleusImageModel.get_train_scheduler() + + pipeline: NucleusMoEImagePipeline = NucleusMoEImagePipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + processor=self.tokenizer[0], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None): + if 
device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + # Move to vae to device if on cpu + if self.vae.device == torch.device("cpu"): + self.vae.to(device) + self.vae.eval() + self.vae.requires_grad_(False) + # move to device and dtype + image_list = [image.to(device, dtype=dtype) for image in image_list] + images = torch.stack(image_list).to(device, dtype=dtype) + # it uses wan vae, so add dim for frame count + + images = images.unsqueeze(2) + latents = self.vae.encode(images).latent_dist.sample() + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + + latents = (latents - latents_mean) * latents_std + latents = latents.to(device, dtype=dtype) + + latents = latents.squeeze(2) # remove the frame count dimension + + return latents + + def generate_single_image( + self, + pipeline: NucleusMoEImagePipeline, + gen_config: GenerateImageConfig, + conditional_embeds: AdvancedPromptEmbeds, + unconditional_embeds: AdvancedPromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + if self.model_config.layer_offloading: + parameters_and_buffers = itertools.chain(self.model.parameters(), self.model.buffers()) + next(parameters_and_buffers).to(self.device_torch) + + + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds[0].unsqueeze(0), + prompt_embeds_mask=conditional_embeds.attention_mask[0].unsqueeze(0), + negative_prompt_embeds=unconditional_embeds.text_embeds[0].unsqueeze(0), + 
negative_prompt_embeds_mask=unconditional_embeds.attention_mask[0].unsqueeze(0), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra, + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: AdvancedPromptEmbeds, + **kwargs, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + if self.model_config.layer_offloading: + parameters_and_buffers = itertools.chain(self.model.parameters(), self.model.buffers()) + next(parameters_and_buffers).to(self.device_torch) + + with torch.no_grad(): + patch_size = self.pipeline.transformer.config.patch_size + + img_shape = (1, latent_model_input.shape[2] // patch_size, latent_model_input.shape[3] // patch_size) + img_shapes = [ + img_shape for _ in range(latent_model_input.shape[0]) + ] + latent_height = latent_model_input.shape[2] + latent_width = latent_model_input.shape[3] + + pixel_height = latent_model_input.shape[2] * self.pipeline.vae_scale_factor + pixel_width = latent_model_input.shape[3] * self.pipeline.vae_scale_factor + + latent_model_input = self.pipeline._pack_latents( + latents=latent_model_input, + batch_size=latent_model_input.shape[0], + num_channels_latents=self.pipeline.transformer.config.in_channels // 4, + height=latent_height, + width=latent_width, + patch_size=patch_size, + ) + + pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + encoder_hidden_states=torch.stack(text_embeddings.text_embeds, dim=0), + encoder_hidden_states_mask=torch.stack(text_embeddings.attention_mask, dim=0), + img_shapes=img_shapes, + return_dict=False, + )[0] + + # invert it + pred = -pred + + pred = self.pipeline._unpack_latents( + latents=pred, + height=pixel_height, + width=pixel_width, + 
patch_size=patch_size, + vae_scale_factor=self.pipeline.vae_scale_factor + ) + + pred = pred.squeeze(2) # remove frame dimension [B, C, 1, H, W] -> [B, C, H, W] + + return pred + + def get_prompt_embeds(self, prompt: str) -> AdvancedPromptEmbeds: + if self.pipeline.text_encoder.device == torch.device("cpu"): + self.pipeline.text_encoder.to(self.device_torch) + + if isinstance(prompt, str): + prompt = [prompt] + + return_index = self.pipeline.default_return_index + device = self.device_torch + + formatted = [self.pipeline._format_prompt(p) for p in prompt] + + inputs = self.pipeline.processor( + text=formatted, + padding="longest", + pad_to_multiple_of=8, + max_length=1024, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ).to(device=device) + + prompt_embeds_mask = inputs.attention_mask + + outputs = self.pipeline.text_encoder(**inputs, use_cache=False, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.hidden_states[return_index] + prompt_embeds = prompt_embeds.to(dtype=self.pipeline.text_encoder.dtype, device=device) + + pe = AdvancedPromptEmbeds( + text_embeds=[x for x in prompt_embeds], + attention_mask=[x for x in prompt_embeds_mask], + ) + return pe + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + transformer: NucleusMoEImageTransformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, "transformer"), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, "aitk_meta.yaml") + with open(meta_path, "w") as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get("noise") + batch = kwargs.get("batch") + return (noise - batch.latents).detach() + + def get_base_model_version(self): + return self.arch + + def get_transformer_block_names(self) -> Optional[List[str]]: + return ["transformer_blocks"] + + def 
convert_lora_weights_before_save(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd diff --git a/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77ce910de160bafaed0565ac01acafea59add020 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/__init__.py @@ -0,0 +1,363 @@ +import os +from typing import TYPE_CHECKING, List, Optional + +import torch +import yaml +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.models.base_model import BaseModel +from diffusers import AutoencoderKL +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype +from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline +from .src.models.transformers import OmniGen2Transformer2DModel +from .src.models.transformers.repo import OmniGen2RotaryPosEmbed +from .src.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler, +) +from PIL import Image +from transformers import ( + CLIPProcessor, + Qwen2_5_VLForConditionalGeneration, +) +import torch.nn.functional as F + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = {"num_train_timesteps": 1000} + +BASE_MODEL_PATH = 
"OmniGen2/OmniGen2" + + +class OmniGen2Model(BaseModel): + arch = "omnigen2" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["OmniGen2Transformer2DModel"] + self._control_latent = None + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return 16 + + def load_model(self): + dtype = self.torch_dtype + # HiDream-ai/HiDream-I1-Full + self.print_and_status_update("Loading OmniGen2 model") + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + extras_path = self.model_config.extras_name_or_path + + scheduler = OmniGen2Model.get_train_scheduler() + + self.print_and_status_update("Loading Qwen2.5 VL") + processor = CLIPProcessor.from_pretrained( + extras_path, subfolder="processor", use_fast=True + ) + + mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained( + extras_path, subfolder="mllm", torch_dtype=torch.bfloat16 + ) + mllm.to(self.device_torch, dtype=dtype) + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Qwen2.5 VL model") + quantization_type = get_qtype(self.model_config.qtype_te) + quantize(mllm, weights=quantization_type) + freeze(mllm) + + if self.low_vram: + # unload it for now + mllm.to("cpu") + + flush() + + self.print_and_status_update("Loading transformer") + + transformer = OmniGen2Transformer2DModel.from_pretrained( + model_path, subfolder="transformer", torch_dtype=torch.bfloat16 + ) + + if not self.low_vram: + transformer.to(self.device_torch, dtype=dtype) + + if self.model_config.quantize: + 
self.print_and_status_update("Quantizing transformer") + quantization_type = get_qtype(self.model_config.qtype) + quantize(transformer, weights=quantization_type) + freeze(transformer) + + if self.low_vram: + # unload it for now + transformer.to("cpu") + + flush() + + self.print_and_status_update("Loading vae") + + vae = AutoencoderKL.from_pretrained( + extras_path, subfolder="vae", torch_dtype=torch.bfloat16 + ).to(self.device_torch, dtype=dtype) + + flush() + self.print_and_status_update("Loading Qwen2.5 VLProcessor") + + flush() + + if self.low_vram: + self.print_and_status_update("Moving everything to device") + # move it all back + transformer.to(self.device_torch, dtype=dtype) + vae.to(self.device_torch, dtype=dtype) + mllm.to(self.device_torch, dtype=dtype) + + # set to eval mode + # transformer.eval() + vae.eval() + mllm.eval() + mllm.requires_grad_(False) + + pipe: OmniGen2Pipeline = OmniGen2Pipeline( + transformer=transformer, + vae=vae, + scheduler=scheduler, + mllm=mllm, + processor=processor, + ) + + flush() + + text_encoder_list = [mllm] + tokenizer_list = [processor] + + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder_list # list of text encoders + self.tokenizer = tokenizer_list # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + + self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis( + transformer.config.axes_dim_rope, + transformer.config.axes_lens, + theta=10000, + ) + + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = OmniFlowMatchEuler( + dynamic_time_shift=True, num_train_timesteps=1000 + ) + + pipeline: OmniGen2Pipeline = OmniGen2Pipeline( + transformer=self.model, + vae=self.vae, + scheduler=scheduler, + mllm=self.text_encoder[0], + processor=self.tokenizer[0], + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: OmniGen2Pipeline, + gen_config: 
GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + input_images = [] + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img) + control_img = control_img.convert("RGB") + # resize to width and height + if control_img.size != (gen_config.width, gen_config.height): + control_img = control_img.resize( + (gen_config.width, gen_config.height), Image.BILINEAR + ) + input_images = [control_img] + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_attention_mask=conditional_embeds.attention_mask, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_attention_mask=unconditional_embeds.attention_mask, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + text_guidance_scale=gen_config.guidance_scale, + image_guidance_scale=1.0, # reference image guidance scale. Add this for controls + latents=gen_config.latents, + align_res=False, + generator=generator, + input_images=input_images, + **extra, + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs, + ): + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + try: + timestep = timestep.expand(latent_model_input.shape[0]).to( + latent_model_input.dtype + ) + except Exception as e: + pass + + timesteps = timestep / 1000 # convert to 0 to 1 scale + # timestep for model starts at 0 instead of 1. 
So we need to reverse them + timestep = 1 - timesteps + model_pred = self.model( + latent_model_input, + timestep, + text_embeddings.text_embeds, + self.freqs_cis, + text_embeddings.attention_mask, + ref_image_hidden_states=self._control_latent, + ) + + return model_pred + + def condition_noisy_latents( + self, latents: torch.Tensor, batch: "DataLoaderBatchDTO" + ): + # reset the control latent + self._control_latent = None + with torch.no_grad(): + control_tensor = batch.control_tensor + if control_tensor is not None: + self.vae.to(self.device_torch) + # we are not packed here, so we just need to pass them so we can pack them later + control_tensor = control_tensor * 2 - 1 + control_tensor = control_tensor.to( + self.vae_device_torch, dtype=self.torch_dtype + ) + + # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it + # todo, we may not need to do this, check + if batch.tensor is not None: + target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3] + else: + # When caching latents, batch.tensor is None. We get the size from the file_items instead. 
+ target_h = batch.file_items[0].crop_height + target_w = batch.file_items[0].crop_width + + if ( + control_tensor.shape[2] != target_h + or control_tensor.shape[3] != target_w + ): + control_tensor = F.interpolate( + control_tensor, size=(target_h, target_w), mode="bilinear" + ) + + control_latent = self.encode_images(control_tensor).to( + latents.device, latents.dtype + ) + self._control_latent = [ + [x.squeeze(0)] + for x in torch.chunk(control_latent, control_latent.shape[0], dim=0) + ] + + return latents.detach() + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt] + self.text_encoder_to(self.device_torch, dtype=self.torch_dtype) + max_sequence_length = 256 + prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt( + prompt=prompt, + do_classifier_free_guidance=False, + device=self.device_torch, + max_sequence_length=max_sequence_length, + ) + pe = PromptEmbeds(prompt_embeds) + pe.attention_mask = prompt_attention_mask + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return False + + def get_te_has_grad(self): + # assume no one wants to finetune 4 text encoders. 
+ return False + + def save_model(self, output_path, meta, save_dtype): + # only save the transformer + transformer: OmniGen2Transformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, "transformer"), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, "aitk_meta.yaml") + with open(meta_path, "w") as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get("noise") + batch = kwargs.get("batch") + # return (noise - batch.latents).detach() + return (batch.latents - noise).detach() + + def get_transformer_block_names(self) -> Optional[List[str]]: + # omnigen2 had a few blocks for things like noise_refiner, ref_image_refiner, context_refiner, and layers. + # lets do all but image refiner until we add it + if self.model_config.model_kwargs.get("use_image_refiner", False): + return ["noise_refiner", "context_refiner", "ref_image_refiner", "layers"] + return ["noise_refiner", "context_refiner", "layers"] + + def convert_lora_weights_before_save(self, state_dict): + # currently starte with transformer. but needs to start with diffusion_model. for comfyui + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + # saved as diffusion_model. but needs to be transformer. 
for ai-toolkit + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + + def get_base_model_version(self): + return "omnigen2" diff --git a/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/models/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..1f713c75f1f8a164440f45974f9540393a483a79 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/models/attention_processor.py @@ -0,0 +1,357 @@ +""" +OmniGen2 Attention Processor Module + +Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
class OmniGen2AttnProcessorFlash2Varlen:
    """
    Attention processor that runs OmniGen2 attention through flash-attn's
    variable-length kernel (``flash_attn_varlen_func``).

    On each call it:
    - projects the inputs to query/key/value and splits them into heads,
    - applies the query/key norms configured on the ``Attention`` module, if any,
    - applies rotary position embeddings when ``image_rotary_emb`` is given,
    - rescales the softmax by ``sqrt(log(seq_len, base_sequence_length))`` when
      ``base_sequence_length`` is given,
    - drops the positions masked out by ``attention_mask``, runs the varlen
      kernel, and pads the result back to ``(batch, seq_len, heads * head_dim)``.

    Raises:
        ImportError: at construction time when ``flash_attn`` is not available.
    """

    def __init__(self) -> None:
        """Fail fast when flash_attn cannot be imported."""
        if not is_flash_attn_available():
            raise ImportError(
                "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. "
                "Please install flash_attn."
            )

    def _upad_input(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        num_heads: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
        """
        Remove padded positions from query/key/value for the varlen kernel.

        Args:
            query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
            key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            attention_mask: Mask of shape (batch_size, seq_len); non-zero marks a kept position
            query_length: Length of the query sequence
            num_heads: Number of query heads

        Returns:
            Tuple containing:
                - Unpadded query tensor
                - Unpadded key tensor
                - Unpadded value tensor
                - Query indices (positions of the kept query tokens in the flattened batch)
                - Tuple of cumulative sequence lengths for query and key
                - Tuple of maximum sequence lengths for query and key
        """

        def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
            """Derive flat token indices, cumulative lengths and max length from the mask."""
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            # Leading zero so that cu_seqlens[i]:cu_seqlens[i + 1] delimits sample i.
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return indices, cu_seqlens, max_seqlen_in_batch

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # Unpad key and value layers
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )

        # Handle different query length cases
        if query_length == kv_seq_len:
            # Queries and keys cover the same positions: reuse the key bookkeeping.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Exactly one query token per sample.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The mask's last `query_length` columns are taken to describe the queries.
            attention_mask = attention_mask[:, -query_length:]
            # Newer flash-attn releases return an extra trailing value from
            # unpad_input; `*_` keeps the unpack valid for both variants.
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Compute attention with the flash-attn varlen kernel.

        Args:
            attn: Attention module providing the projections, norms, head count and scale
            hidden_states: Tensor of shape (batch_size, seq_len, hidden_dim) used for queries
            encoder_hidden_states: Tensor used for keys and values
            attention_mask: Optional (batch_size, kv_seq_len) mask; ``None`` keeps every position
            image_rotary_emb: Optional rotary embeddings applied to query and key
            base_sequence_length: Optional base for proportional attention scaling

        Returns:
            torch.Tensor: Output of the attention block after the output projection
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply Rotary Position Embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        # Norm / RoPE may have changed the dtype; restore the projection dtype.
        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        if attention_mask is None:
            # The signature allows omitting the mask; treat that as "nothing is
            # padded" instead of failing inside _upad_input.
            attention_mask = torch.ones(
                batch_size, key.shape[1], dtype=torch.bool, device=query.device
            )

        # Unpad input for flash attention
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            cu_seq_lens,
            max_seq_lens,
        ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        # Handle different number of heads: repeat each kv head so counts match.
        if kv_heads < attn.heads:
            key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
            value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)

        # Apply flash attention
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=softmax_scale,
        )

        # Pad output and apply final transformations
        hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
        hidden_states = hidden_states.flatten(-2)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
class OmniGen2AttnProcessor:
    """
    Attention processor for OmniGen2 built on PyTorch's
    ``F.scaled_dot_product_attention``.

    On each call it:
    - projects the inputs to query/key/value and splits them into heads,
    - applies the query/key norms configured on the ``Attention`` module, if any,
    - applies rotary position embeddings when ``image_rotary_emb`` is given,
    - rescales the softmax by ``sqrt(log(seq_len, base_sequence_length))`` when
      ``base_sequence_length`` is given,
    - repeats key/value heads so their count matches the query heads.

    Raises:
        ImportError: at construction time when ``F.scaled_dot_product_attention``
            does not exist (PyTorch older than 2.0).
    """

    def __init__(self) -> None:
        """Fail fast when the installed PyTorch lacks scaled_dot_product_attention."""
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "OmniGen2AttnProcessor requires PyTorch 2.0. "
                "Please upgrade PyTorch to version 2.0 or later."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Compute attention with ``F.scaled_dot_product_attention``.

        Args:
            attn: Attention module providing the projections, norms, head count and scale
            hidden_states: Tensor of shape (batch_size, seq_len, hidden_dim) used for queries
            encoder_hidden_states: Tensor used for keys and values
            attention_mask: Optional mask with ``batch_size * kv_len`` elements; it is cast
                to bool and viewed as (batch_size, 1, 1, kv_len)
            image_rotary_emb: Optional rotary embeddings applied to query and key
            base_sequence_length: Optional base for proportional attention scaling

        Returns:
            torch.Tensor: Output of the attention block after the output projection
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply Rotary Position Embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        # Norm / RoPE may have changed the dtype; restore the projection dtype.
        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        if attention_mask is not None:
            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # explicitly repeat key and value to match query length, otherwise using enable_gqa=True
        # results in MATH backend of sdpa in our test of pytorch2.6
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
class TimestepEmbedding(nn.Module):
    """
    Two-layer MLP that embeds a projected timestep.

    Computes ``linear_2(act(linear_1(sample [+ cond_proj(condition)])))`` followed
    by an optional post-activation.

    Args:
        in_channels: Size of the incoming timestep projection.
        time_embed_dim: Hidden size of the MLP.
        act_fn: Name passed to ``get_activation`` for the hidden activation.
        out_dim: Output size; defaults to ``time_embed_dim`` when ``None``.
        post_act_fn: Optional activation name applied to the output.
        cond_proj_dim: When set, a bias-free linear layer maps a condition of this
            size to ``in_channels`` so it can be added to ``sample``.
        sample_proj_bias: Whether ``linear_1`` and ``linear_2`` have a bias.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

        self.initialize_weights()

    def initialize_weights(self):
        """Draw both weights from N(0, 0.02) and zero the biases that exist."""
        for layer in (self.linear_1, self.linear_2):
            nn.init.normal_(layer.weight, std=0.02)
            # With sample_proj_bias=False the layer has no bias tensor;
            # nn.init.zeros_(None) would raise, so skip it.
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, sample, condition=None):
        """
        Embed ``sample``.

        Args:
            sample: Timestep projection whose last dimension is ``in_channels``.
            condition: Optional conditioning tensor; only usable when the module
                was built with ``cond_proj_dim``.

        Returns:
            The embedded tensor with last dimension ``out_dim`` (or ``time_embed_dim``).
        """
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary position embeddings to a single query or key tensor.

    Args:
        x (`torch.Tensor`):
            4-D query or key tensor. The last dimension is the head dimension and
            must be even.
        freqs_cis:
            - ``use_real=True``: a ``(cos, sin)`` pair of real tensors. Each gets two
              leading singleton dimensions and is broadcast against ``x``.
            - ``use_real=False``: one complex tensor. A singleton dimension is inserted
              at index 2 and it is multiplied with ``x`` viewed as complex numbers
              (adjacent pairs of the last dimension form real/imaginary parts).
        use_real (`bool`): Selects the real ``(cos, sin)`` path or the complex path.
        use_real_unbind_dim (`int`):
            Only used when ``use_real=True``. ``-1`` treats adjacent elements of the
            last dimension as (real, imag) pairs; ``-2`` treats the first and second
            half of the last dimension as real and imag parts.

    Returns:
        `torch.Tensor`: ``x`` with rotary embeddings applied, same shape and dtype as ``x``.

    Raises:
        ValueError: If ``use_real`` is set and ``use_real_unbind_dim`` is not -1 or -2.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen and CogView4
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        # Rotate in float32, then return in the caller's dtype.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
0000000000000000000000000000000000000000..13739d3a596d196411b85af19a55b34eb566b0cf --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/models/transformers/block_lumina2.py @@ -0,0 +1,218 @@ + +# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from diffusers.models.embeddings import Timesteps +from ..embeddings import TimestepEmbedding + +from ...utils.import_utils import is_flash_attn_available, is_triton_available + +if is_triton_available(): + from ...ops.triton.layer_norm import RMSNorm +else: + from torch.nn import RMSNorm + warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance") + +if is_flash_attn_available(): + from flash_attn.ops.activations import swiglu +else: + from .components import swiglu + warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance") + +# try: +# from flash_attn.ops.activations import swiglu as fused_swiglu +# FUSEDSWIGLU_AVALIBLE = True +# except ImportError: + +# FUSEDSWIGLU_AVALIBLE = False +# warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ """ + + def __init__( + self, + embedding_dim: int, + norm_eps: float, + norm_elementwise_affine: bool, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + + self.norm = RMSNorm(embedding_dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + return x, gate_msa, scale_mlp, gate_mlp + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. 
class LuminaFeedForward(nn.Module):
    r"""
    Gated feed-forward layer: ``linear_2(swiglu(linear_1(x), linear_3(x)))``.

    Parameters:
        dim (`int`):
            Size of the input and output features.
        inner_dim (`int`):
            Requested hidden width, before the multiplier and rounding are applied.
        multiple_of (`int`, *optional*, defaults to 256):
            The hidden width is rounded up to a multiple of this value.
        ffn_dim_multiplier (`float`, *optional*):
            When given, `inner_dim` is first scaled by it and truncated to an int.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        self.swiglu = swiglu

        # optional custom scaling of the hidden width
        hidden = inner_dim if ffn_dim_multiplier is None else int(ffn_dim_multiplier * inner_dim)
        # round up to the next multiple of `multiple_of`
        n_groups = (hidden + multiple_of - 1) // multiple_of
        hidden = multiple_of * n_groups

        self.linear_1 = nn.Linear(dim, hidden, bias=False)
        self.linear_2 = nn.Linear(hidden, dim, bias=False)
        self.linear_3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` and map back to ``dim`` features."""
        gate_in = self.linear_1(x)
        value_in = self.linear_3(x)
        gated = self.swiglu(gate_in, value_in)
        return self.linear_2(gated)
+ """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + self.swiglu = swiglu + + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + + def forward(self, x): + h1, h2 = self.linear_1(x), self.linear_3(x) + return self.linear_2(self.swiglu(h1, h2)) + + +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + text_feat_dim: int = 2048, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + timestep_scale: float = 1.0, + ) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + RMSNorm(text_feat_dim, eps=norm_eps), + nn.Linear(text_feat_dim, hidden_size, bias=True), + ) + + self._initialize_weights() + + def _initialize_weights(self): + nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02) + nn.init.zeros_(self.caption_embedder[1].bias) + + def forward( + self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).to(dtype=dtype) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(text_hidden_states) + return time_embed, caption_embed \ No newline at end of file diff --git 
def swiglu(x, y):
    """
    Unfused SwiGLU gate: ``silu(x) * y``.

    SiLU is evaluated in float32 and cast back to ``x``'s dtype before the
    multiplication with ``y``.
    """
    gate = F.silu(x.float(), inplace=False)
    gate = gate.to(x.dtype)
    return gate * y
class OmniGen2RotaryPosEmbed(nn.Module):
    """
    Builds 3-axis rotary position embeddings for OmniGen2 sequences laid out as
    ``[caption tokens | reference-image tokens | image tokens]``.

    Axis 0 of a position id is a running sequence offset, axes 1 and 2 are the row
    and column of an image token. Caption tokens use their index on all three axes.

    Args:
        theta: Base passed to ``get_1d_rotary_pos_embed``.
        axes_dim: Rotary dimension per axis.
        axes_lens: Number of positions tabulated per axis.
        patch_size: Side length, in the units of the given image sizes, of one token.
    """

    def __init__(self, theta: int,
                 axes_dim: Tuple[int, int, int],
                 axes_lens: Tuple[int, int, int] = (300, 512, 512),
                 patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(axes_dim: Tuple[int, int, int],
                      axes_lens: Tuple[int, int, int],
                      theta: int) -> List[torch.Tensor]:
        """Return one frequency table per axis, built with ``get_1d_rotary_pos_embed``."""
        # float64 is not used when an MPS backend is available.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        return [
            get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
            for d, e in zip(axes_dim, axes_lens)
        ]

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        """
        Look up the per-axis tables at the given position ids.

        Args:
            freqs_cis: Sequence of per-axis tables, each of shape (axis_len, d_i).
            ids: Integer tensor of shape (batch, seq_len, n_axes).

        Returns:
            Tensor of shape (batch, seq_len, sum(d_i)) on ``ids``' original device.
        """
        device = ids.device
        if ids.device.type == "mps":
            # The lookup is done on CPU for MPS inputs and moved back afterwards.
            ids = ids.to("cpu")

        result = []
        for i in range(len(self.axes_dim)):
            freqs = freqs_cis[i].to(ids.device)
            # Row lookup: out[b, l, :] = freqs[ids[b, l, i], :]. Same values as
            # gathering from a per-batch copy of the table, without the copy.
            result.append(freqs[ids[:, :, i].to(torch.int64)])
        return torch.cat(result, dim=-1).to(device)

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device
    ):
        """
        Args:
            freqs_cis: Per-axis frequency tables (see ``get_freqs_cis``).
            attention_mask: (batch, encoder_seq_len) caption mask; its row sums are
                used as the caption lengths.
            l_effective_ref_img_len: Per sample, list of token counts of each reference image.
            l_effective_img_len: Per sample, token count of the target image.
            ref_img_sizes: Per sample, list of (H, W) of each reference image, or ``None``.
            img_sizes: Per sample, (H, W) of the target image.
            device: Device on which position ids and outputs are created.

        Returns:
            ``(cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis,
            l_effective_cap_len, seq_lengths)`` where the first three are zero-padded
            slices of the combined ``freqs_cis``.
        """
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()

        seq_lengths = [
            cap_len + sum(ref_img_len) + img_len
            for cap_len, ref_img_len, img_len in zip(
                l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len
            )
        ]

        max_seq_len = int(max(seq_lengths))
        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # Create position IDs
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            cap_seq_len = int(cap_seq_len)
            seq_len = int(seq_len)
            # add text position ids: the token index on all three axes
            cap_positions = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
            position_ids[i, :cap_seq_len] = cap_positions.unsqueeze(-1).expand(-1, 3)

            pe_shift = cap_seq_len      # value written to axis 0 for the next segment
            pe_shift_len = cap_seq_len  # index of the next free sequence slot

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len
                    # add image position ids (row-major token grid)
                    row_ids = torch.arange(ref_H_tokens, dtype=torch.int32, device=device).repeat_interleave(ref_W_tokens)
                    col_ids = torch.arange(ref_W_tokens, dtype=torch.int32, device=device).repeat(ref_H_tokens)
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids

                    # The axis-0 offset advances by the larger grid side, not by
                    # the number of tokens.
                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i]

            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).repeat_interleave(W_tokens)
            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).repeat(H_tokens)

            assert pe_shift_len + l_effective_img_len[i] == seq_len
            position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len: seq_len, 1] = row_ids
            position_ids[i, pe_shift_len: seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)

        # create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        ref_img_freqs_cis = torch.zeros(
            batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )

        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(
            zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)
        ):
            cap_seq_len = int(cap_seq_len)
            sum_ref_img_len = int(sum(ref_img_len))
            img_len = int(img_len)
            seq_len = int(seq_len)
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :sum_ref_img_len] = freqs_cis[i, cap_seq_len:cap_seq_len + sum_ref_img_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum_ref_img_len:cap_seq_len + sum_ref_img_len + img_len]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )
class OmniGen2TransformerBlock(nn.Module):
    """
    Single transformer layer used throughout the OmniGen2 model.

    The block applies self-attention followed by a feed-forward network
    (`LuminaFeedForward`). Both sub-layer outputs are RMS-normalized *before*
    being added to the residual stream. When `modulation` is enabled, the
    timestep embedding produces gates/scales through `LuminaRMSNormZero`.

    Args:
        dim: Width of the input and output hidden states.
        num_attention_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        multiple_of: Forwarded to `LuminaFeedForward` (rounding granularity of
            its inner dimension).
        ffn_dim_multiplier: Optional multiplier forwarded to
            `LuminaFeedForward`; the enclosing model passes `None` by default.
        norm_eps: Epsilon used by the normalization layers.
        modulation: If True, `norm1` is a `LuminaRMSNormZero` conditioned on
            `temb`; otherwise a plain `RMSNorm`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        """Build the attention, feed-forward and normalization sub-modules."""
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        # Prefer the flash-attention varlen processor; fall back to the
        # default processor when constructing it raises ImportError.
        try:
            processor = OmniGen2AttnProcessorFlash2Varlen()
        except ImportError:
            processor = OmniGen2AttnProcessor()

        # Initialize attention layer
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=processor,
        )

        # Initialize feed-forward network
        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier
        )

        # Initialize normalization layers
        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim,
                norm_eps=norm_eps,
                norm_elementwise_affine=True
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)

        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the transformer block.

        Attention and feed-forward projection weights get Xavier-uniform
        initialization. When modulation is enabled, the modulation linear of
        `norm1` (weight and bias) is zero-initialized.
        """
        nn.init.xavier_uniform_(self.attn.to_q.weight)
        nn.init.xavier_uniform_(self.attn.to_k.weight)
        nn.init.xavier_uniform_(self.attn.to_v.weight)
        nn.init.xavier_uniform_(self.attn.to_out[0].weight)

        nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)

        if self.modulation:
            nn.init.zeros_(self.norm1.linear.weight)
            nn.init.zeros_(self.norm1.linear.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the transformer block.

        Args:
            hidden_states: Input hidden states tensor.
            attention_mask: Attention mask tensor, forwarded to the attention layer.
            image_rotary_emb: Rotary embeddings, forwarded to the attention layer.
            temb: Timestep embedding; required when modulation is enabled.

        Returns:
            torch.Tensor: Hidden states after the attention and feed-forward
            residual updates.

        Raises:
            ValueError: If modulation is enabled and `temb` is None.
        """
        if self.modulation:
            if temb is None:
                raise ValueError("temb must be provided when modulation is enabled")

            norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
            # Self-attention: the normalized states serve as both query and key/value source.
            attn_output = self.attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            # Residual updates are gated by tanh of the modulation gates.
            hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
            hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
        else:
            norm_hidden_states = self.norm1(hidden_states)
            attn_output = self.attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            hidden_states = hidden_states + self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
            hidden_states = hidden_states + self.ffn_norm2(mlp_output)

        return hidden_states
class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    OmniGen2 Transformer 2D Model.

    A transformer-based diffusion model for image generation with:
    - Patch-based image processing
    - Rotary position embeddings
    - Multi-head attention
    - Conditional generation support (text and optional reference images)

    Args:
        patch_size: Size of image patches
        in_channels: Number of input channels
        out_channels: Number of output channels (defaults to in_channels)
        hidden_size: Size of hidden layers
        num_layers: Number of joint transformer layers
        num_refiner_layers: Number of layers in each refiner stack
            (noise, reference-image and context refiners)
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Multiple of which the feed-forward hidden dimension should be
        ffn_dim_multiplier: Multiplier for feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        axes_dim_rope: Per-axis dimensions for rotary position embeddings;
            must sum to `hidden_size // num_attention_heads`
        axes_lens: Per-axis lengths for rotary position embeddings
        text_feat_dim: Dimension of text features
        timestep_scale: Scale factor for timestep embeddings
    """

    _supports_gradient_checkpointing = True
    # Must match the block's class name exactly, otherwise the no-split hint is ignored.
    _no_split_modules = ["OmniGen2TransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        text_feat_dim: int = 1024,
        timestep_scale: float = 1.0
    ) -> None:
        """Initialize the OmniGen2 transformer model.

        Raises:
            ValueError: If the per-head dimension does not equal
                `sum(axes_dim_rope)`.
        """
        super().__init__()

        # Validate configuration
        if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
            raise ValueError(
                f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
                f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
            )

        self.out_channels = out_channels or in_channels

        # Initialize embeddings
        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )

        self.x_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.ref_image_patch_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=text_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale
        )

        # Initialize transformer blocks
        self.noise_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True
            )
            for _ in range(num_refiner_layers)
        ])

        self.ref_image_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True
            )
            for _ in range(num_refiner_layers)
        ])

        # The text/context refiner is not conditioned on the timestep.
        self.context_refiner = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels
        )

        # Add learnable embeddings to distinguish different images
        self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size))  # support max 5 ref images

        self.gradient_checkpointing = False

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the model.

        Patch embedders use Xavier-uniform weights and zero biases, the output
        norm's linears are zeroed, and the image-index embedding is drawn from
        a normal distribution with std 0.02.
        """
        nn.init.xavier_uniform_(self.x_embedder.weight)
        nn.init.constant_(self.x_embedder.bias, 0.0)

        nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
        nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)

        nn.init.zeros_(self.norm_out.linear_1.weight)
        nn.init.zeros_(self.norm_out.linear_1.bias)
        nn.init.zeros_(self.norm_out.linear_2.weight)
        nn.init.zeros_(self.norm_out.linear_2.bias)

        nn.init.normal_(self.image_index_embedding, std=0.02)

    def img_patch_embed_and_refine(
        self,
        hidden_states,
        ref_image_hidden_states,
        padded_img_mask,
        padded_ref_img_mask,
        noise_rotary_emb,
        ref_img_rotary_emb,
        l_effective_ref_img_len,
        l_effective_img_len,
        temb
    ):
        """
        Embed and refine the noisy-image and reference-image patch sequences.

        The noisy image tokens go through `noise_refiner`. Each reference image
        is tagged with its index embedding, refined on its own through
        `ref_image_refiner`, and written back. The result is one padded tensor
        per sample laid out as `[reference tokens | image tokens | padding]`.

        `padded_ref_img_mask` is accepted for interface symmetry but is not
        read; per-image masks are rebuilt from `l_effective_ref_img_len`.

        Returns:
            Tensor of shape `(batch, max combined length, hidden_size)`.
        """
        batch_size = len(hidden_states)
        max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])

        hidden_states = self.x_embedder(hidden_states)
        ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)

        # Tag the tokens of the j-th reference image with the j-th index embedding.
        for i in range(batch_size):
            shift = 0
            for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
                ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
                shift += ref_img_len

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)

        flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
        num_ref_images = len(flat_l_effective_ref_img_len)
        max_ref_img_len = max(flat_l_effective_ref_img_len)

        batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
        batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
        batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
        batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)

        # sequence of ref imgs to batch
        idx = 0
        for i in range(batch_size):
            shift = 0
            for ref_img_len in l_effective_ref_img_len[i]:
                batch_ref_img_mask[idx, :ref_img_len] = True
                batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
                batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
                batch_temb[idx] = temb[i]
                shift += ref_img_len
                idx += 1

        # refine ref imgs separately
        for layer in self.ref_image_refiner:
            batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)

        # batch of ref imgs to sequence
        idx = 0
        for i in range(batch_size):
            shift = 0
            for ref_img_len in l_effective_ref_img_len[i]:
                ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
                shift += ref_img_len
                idx += 1

        combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
        for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
            combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
            combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]

        return combined_img_hidden_states

    def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
        """
        Patchify per-sample images into token sequences and pad them to a batch.

        Args:
            hidden_states: Sequence of `(C, H, W)` latents, one per sample.
            ref_image_hidden_states: Optional per-sample lists of `(C, H, W)`
                reference latents; an entry (or the whole argument) may be None.

        Returns:
            Tuple of padded image tokens, padded reference tokens, their
            boolean validity masks, per-sample reference token counts (a list
            per sample, `[0]` when there are none), per-sample image token
            counts, reference image sizes and image sizes.
        """
        batch_size = len(hidden_states)
        p = self.config.patch_size
        device = hidden_states[0].device

        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]

        if ref_image_hidden_states is not None:
            ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
            l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
        else:
            ref_img_sizes = [None for _ in range(batch_size)]
            l_effective_ref_img_len = [[0] for _ in range(batch_size)]

        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # ref image patch embeddings
        flat_ref_img_hidden_states = []
        for i in range(batch_size):
            if ref_img_sizes[i] is not None:
                imgs = []
                for ref_img in ref_image_hidden_states[i]:
                    ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
                    imgs.append(ref_img)

                img = torch.cat(imgs, dim=0)
                flat_ref_img_hidden_states.append(img)
            else:
                flat_ref_img_hidden_states.append(None)

        # image patch embeddings
        flat_hidden_states = []
        for i in range(batch_size):
            img = hidden_states[i]
            img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
            flat_hidden_states.append(img)

        padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
        padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            if ref_img_sizes[i] is not None:
                padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
                padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True

        padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
            padded_img_mask[i, :l_effective_img_len[i]] = True

        return (
            padded_hidden_states,
            padded_ref_img_hidden_states,
            padded_img_mask,
            padded_ref_img_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
        )

    def forward(
        self,
        hidden_states: Union[torch.Tensor, List[torch.Tensor]],
        timestep: torch.Tensor,
        text_hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        text_attention_mask: torch.Tensor,
        ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        Run the model on a batch of noisy latents.

        Args:
            hidden_states: A 4-D tensor or a list of `(C, H, W)` latents.
            timestep: Timesteps passed to the timestep/caption embedder.
            text_hidden_states: Text encoder features.
            freqs_cis: Rotary frequency tables passed to the rope embedder.
            text_attention_mask: Validity mask for the text tokens.
            ref_image_hidden_states: Optional per-sample lists of reference latents.
            attention_kwargs: Optional dict; only its `scale` entry (LoRA
                scale) is read.
            return_dict: If True, wrap the output in `Transformer2DModelOutput`.

        Returns:
            The predicted latents: a stacked tensor when the input was a
            tensor, otherwise a list of `(C, H, W)` tensors; wrapped in
            `Transformer2DModelOutput` when `return_dict` is True.
        """
        scale_was_passed = False
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            # Record this before popping; afterwards the key is gone and the
            # non-PEFT warning below could never fire.
            scale_was_passed = attention_kwargs.get("scale", None) is not None
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        # 1. Condition, positional & patch embedding
        batch_size = len(hidden_states)
        is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)

        if is_hidden_states_tensor:
            assert hidden_states.ndim == 4
            hidden_states = [_hidden_states for _hidden_states in hidden_states]

        device = hidden_states[0].device

        temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)

        (
            hidden_states,
            ref_image_hidden_states,
            img_mask,
            ref_img_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

        (
            context_rotary_emb,
            ref_img_rotary_emb,
            noise_rotary_emb,
            rotary_emb,
            encoder_seq_lengths,
            seq_lengths,
        ) = self.rope_embedder(
            freqs_cis,
            text_attention_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
            device,
        )

        # 2. Context refinement
        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)

        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states,
            ref_image_hidden_states,
            img_mask,
            ref_img_mask,
            noise_rotary_emb,
            ref_img_rotary_emb,
            l_effective_ref_img_len,
            l_effective_img_len,
            temb,
        )

        # 3. Joint Transformer blocks
        max_seq_len = int(max(seq_lengths))

        # Per-sample layout: [text tokens | reference tokens | image tokens | padding]
        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            encoder_seq_len = int(encoder_seq_len)
            seq_len = int(seq_len)
            attention_mask[i, :seq_len] = True
            joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
            joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]

        hidden_states = joint_hidden_states

        for layer in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    layer, hidden_states, attention_mask, rotary_emb, temb
                )
            else:
                hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)

        # 4. Output norm & projection
        hidden_states = self.norm_out(hidden_states, temb)

        p = self.config.patch_size
        output = []
        for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
            img_len = int(img_len)
            seq_len = int(seq_len)
            height, width = img_size
            # The noisy-image tokens are the last `img_len` valid tokens of each sequence.
            output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
        if is_hidden_states_tensor:
            output = torch.stack(output, dim=0)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return output
        return Transformer2DModelOutput(sample=output)
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    """Wrap an AMP decorator factory so it works with both torch AMP APIs.

    When `cuda_amp_deprecated` is true, every call to the returned callable
    forces `device_type="cuda"` into the keyword arguments before delegating
    to `dec`; otherwise the call is forwarded unchanged.
    """
    def decorator(*args, **kwargs):
        forced = {"device_type": "cuda"} if cuda_amp_deprecated else {}
        kwargs.update(forced)
        return dec(*args, **kwargs)
    return decorator


# `torch.amp.custom_fwd` present means `torch.cuda.amp.*` is the deprecated spelling.
deprecated = hasattr(torch.amp, "custom_fwd")  # type: ignore[attr-defined]
if deprecated:
    from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
else:
    from torch.cuda.amp import custom_fwd, custom_bwd

custom_fwd, custom_bwd = (
    custom_amp_decorator(amp_dec, deprecated) for amp_dec in (custom_fwd, custom_bwd)
)


def triton_autotune_configs():
    """Return autotune configs whose warp counts are valid for the current device.

    Warp counts are the powers of two whose total thread count
    (`num_warps * warp_size`) does not exceed 1024 threads per block.
    """
    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
    thread_limit = 1024
    device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
    # Default to warp size 32 if not defined by device
    lanes_per_warp = getattr(device_props, "warp_size", 32)
    candidates = []
    num_warps = 1
    while num_warps * lanes_per_warp <= thread_limit:
        candidates.append(triton.Config({}, num_warps=num_warps))
        num_warps *= 2
    return candidates
def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    zero_centered_weight=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    """Pure-PyTorch reference for dropout + residual add + LayerNorm.

    The pre-norm input is built as `dropout(x * rowscale) + dropout(x1) + residual`
    (each term only when supplied), then normalized with `(weight, bias)` and,
    if `weight1` is given, a second time with `(weight1, bias1)`.

    Returns `out`, or `(out, out1)` when `weight1` is given; when `prenorm` is
    true the pre-norm input is appended as the last tuple element. Outputs are
    cast back to the dtype `x` had on entry.
    """
    result_dtype = x.dtype

    def _maybe_float(t):
        return t.float() if t is not None else None

    def _apply_dropout(t, keep):
        # A caller-supplied mask makes the dropout deterministic.
        if keep is not None:
            return t.masked_fill(~keep, 0.0) / (1.0 - dropout_p)
        return F.dropout(t, p=dropout_p)

    if upcast:
        x = x.float()
        weight = weight.float()
        bias, residual, x1, weight1, bias1 = (
            _maybe_float(t) for t in (bias, residual, x1, weight1, bias1)
        )
    if zero_centered_weight:
        weight = weight + 1.0
        if weight1 is not None:
            weight1 = weight1 + 1.0
    assert x1 is None or rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        x = _apply_dropout(x, dropout_mask)
        if x1 is not None:
            x1 = _apply_dropout(x1, dropout_mask1)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)

    def _normalize(w, b):
        normed = F.layer_norm(x.to(w.dtype), x.shape[-1:], weight=w, bias=b, eps=eps)
        return normed.to(result_dtype)

    results = [_normalize(weight, bias)]
    if weight1 is not None:
        results.append(_normalize(weight1, bias1))
    if prenorm:
        results.append(x)
    if len(results) == 1:
        return results[0]
    return tuple(results)
def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    zero_centered_weight=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    """Pure-PyTorch reference for dropout + residual add + RMSNorm.

    The pre-norm input is assembled from `x` (optionally row-scaled and
    dropped out), the optional parallel branch `x1` and the optional
    `residual`. It is then scaled by `1 / sqrt(mean(x^2) + eps)` over the last
    dimension and multiplied by `weight` (plus `bias` if given). If `weight1`
    is given, a second output is produced with `(weight1, bias1)`.

    Returns `out`, or `(out, out1)` when `weight1` is given; when `prenorm` is
    true the pre-norm input is appended as the last tuple element. Outputs are
    cast back to the dtype `x` had on entry.
    """
    result_dtype = x.dtype

    if upcast:
        x = x.float()
        weight = weight.float()
        if bias is not None:
            bias = bias.float()
        if residual is not None:
            residual = residual.float()
        if x1 is not None:
            x1 = x1.float()
        if weight1 is not None:
            weight1 = weight1.float()
        if bias1 is not None:
            bias1 = bias1.float()

    if zero_centered_weight:
        weight = weight + 1.0
        weight1 = weight1 + 1.0 if weight1 is not None else weight1

    has_parallel_branch = x1 is not None
    if has_parallel_branch:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]

    if dropout_p > 0.0:
        keep_scale = 1.0 - dropout_p
        if dropout_mask is None:
            x = F.dropout(x, p=dropout_p)
        else:
            x = x.masked_fill(~dropout_mask, 0.0) / keep_scale
        if has_parallel_branch:
            if dropout_mask1 is None:
                x1 = F.dropout(x1, p=dropout_p)
            else:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / keep_scale

    if has_parallel_branch:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)

    # Reciprocal RMS over the feature (last) dimension.
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)

    def _scale(w, b):
        scaled = x * rstd * w
        if b is not None:
            scaled = scaled + b
        return scaled.to(result_dtype)

    out = _scale(weight, bias)
    if weight1 is None:
        return (out, x) if prenorm else out
    out1 = _scale(weight1, bias1)
    return (out, out1, x) if prenorm else (out, out1)
# Fused forward kernel: optional rowscale / dropout / parallel-branch add /
# residual add, followed by LayerNorm or RMSNorm, for one row per program
# instance. The launcher in this file starts it on a grid of (M,).
# HAS_X1 / HAS_W1 / HAS_B1 are derived from whether the corresponding pointer
# is None; the remaining constexpr flags are passed explicitly by the launcher.
@triton.autotune(
    configs=triton_autotune_configs(),
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,  # pointer to the optional parallel-branch input (added to X)
    W1,  # pointer to the optional second set of weights
    B1,  # pointer to the optional second set of biases
    Y1,  # pointer to the second output (written only when HAS_W1)
    RESIDUAL_OUT,  # pointer to the stored pre-norm sum (written when STORE_RESIDUAL_OUT)
    ROWSCALE,  # pointer to per-row scale factors
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,  # pointer to the stored keep-mask (written when STORE_DROPOUT_MASK)
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    zero_centered_weight,  # If true, add 1.0 to the weight
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    # The whole row is loaded at once; columns >= N are masked and read as 0.
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            # The mask buffer is indexed as a dense (rows, N) array.
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            # The second branch reads its per-row values at offset M (rows M..2M-1).
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        # Persist the pre-norm sum (after dropout / x1 / residual).
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        # Zero the padded columns so they do not contribute to the variance.
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMSNorm: no mean subtraction, and Mean is not written.
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if zero_centered_weight:
        w += 1.0
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        # Second output: same normalized row, different affine parameters.
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if zero_centered_weight:
            w1 += 1.0
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    zero_centered_weight=False,
    is_rms_norm=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None
):
    """Validate inputs, allocate buffers and launch the fused forward kernel.

    `x` must be 2-D `(M, N)` with unit stride in the last dimension; the same
    stride requirement is asserted for every other tensor argument.

    Returns:
        Tuple `(out, y1, mean, rstd, residual_out_or_x, seeds, dropout_mask,
        dropout_mask1)`. `y1` is None without `weight1`; `mean` is None for
        RMSNorm; the fifth element is `x` itself when no separate pre-norm
        buffer was needed; `seeds` and the masks are None when unused.

    Raises:
        RuntimeError: If a row does not fit in one 64KB block.
        AssertionError: If a shape or stride precondition is violated.
    """
    # A supplied residual dictates the dtype of the stored pre-norm sum.
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        # rowscale and the parallel branch x1 are mutually exclusive here.
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    if out is None:
        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    else:
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(out)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    # A separate pre-norm buffer is needed whenever the normalized input
    # differs from x (residual / dropout / rowscale / x1) or must be stored
    # in a different dtype.
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        if residual_out is None:
            residual_out = torch.empty(
                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
            )
        else:
            assert residual_out.shape == x.shape
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        # One seed per row; twice as many when x1 gets its own dropout.
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        # Arguments are positional and must follow the kernel signature order.
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            out,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            zero_centered_weight,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        # First M rows belong to x, last M rows to x1.
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        out,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )
tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, 
axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + dy1=None, + weight1=None, + bias1=None, + seeds=None, + dropout_p=0.0, + rowscale=None, + has_residual=False, + has_x1=False, + zero_centered_weight=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + 
assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if dy1 is not None: + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = ( + torch.empty_like(x) + if has_residual + and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + if recompute_output: + assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. 
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, + dresidual_in, + rowscale, + seeds, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if dx1 is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + dropout_p > 0.0, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return ( + (dx, dw, db, dresidual_in, dx1, dw1, db1) + if not recompute_output + else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) + ) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + 
prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None + ): + x_shape_og = x.shape + # Check for zero sequence length + if x.numel() == 0: + ctx.zero_seq_length = True + # Only save minimal required tensors for backward + # ctx.save_for_backward(weight, bias, weight1, bias1) + ctx.x_shape_og = x_shape_og + ctx.weight_shape = weight.shape + ctx.weight_dtype = weight.dtype + ctx.weight_device = weight.device + + ctx.has_bias = bias is not None + ctx.bias_shape = bias.shape if bias is not None else None + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.bias_device = bias.device if bias is not None else None + + ctx.has_weight1 = weight1 is not None + ctx.weight1_shape = weight1.shape if weight1 is not None else None + ctx.weight1_dtype = weight1.dtype if weight1 is not None else None + ctx.weight1_device = weight1.device if weight1 is not None else None + + ctx.has_bias1 = bias1 is not None + ctx.bias1_shape = bias1.shape if bias1 is not None else None + ctx.bias1_dtype = bias1.dtype if bias1 is not None else None + ctx.bias1_device = bias1.device if bias1 is not None else None + + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.dropout_p = dropout_p + + # Handle output tensors with correct dtype + y = x # Preserve input tensor properties + y1 = torch.empty_like(x) if x1 is not None else None + + # Only create residual_out if prenorm is True + residual_out = torch.empty(x.shape, + dtype=torch.float32 if residual_in_fp32 else x.dtype, + device=x.device) if prenorm else None + + # Handle dropout masks + dropout_mask = None + dropout_mask1 = None + if return_dropout_mask: + dropout_mask = torch.empty_like(x, dtype=torch.uint8) + if x1 is not None: + dropout_mask1 = torch.empty_like(x, dtype=torch.uint8) + + # Return based on configuration + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, 
residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ((y, dropout_mask, dropout_mask1) if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1)) + else: + return ((y, y1, dropout_mask, dropout_mask1) if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1)) + + ctx.zero_seq_length = False + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = x1.reshape(-1, x1.shape[-1]) + if x1.stride(-1) != 1: + x1 = x1.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + if weight1 is not None: + weight1 = weight1.contiguous() + if bias1 is not None: + bias1 = bias1.contiguous() + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out + ) + ctx.save_for_backward( + residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd + ) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + 
ctx.dropout_p = dropout_p + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.zero_centered_weight = zero_centered_weight + y = y.reshape(x_shape_og) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + @staticmethod + def backward(ctx, dy, *args): + if ctx.zero_seq_length: + return ( + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device), + torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device), + torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None, + torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None, + torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, 
dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + if dy1.stride(-1) != 1: + dy1 = dy1.contiguous() + assert dy1.shape == x.shape + else: + dy1 = None + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + + dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, + ctx.has_residual, + ctx.has_x1, + ctx.zero_centered_weight, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out, + residual_out + ) + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + 
residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out, + residual_out + ) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, + device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if dropout_p > 0.0: + self.drop = torch.nn.Dropout(dropout_p) + else: + self.drop = None + self.zero_centered_weight = zero_centered_weight + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + if not self.zero_centered_weight: + torch.nn.init.ones_(self.weight) + else: + torch.nn.init.zeros_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + zero_centered_weight=self.zero_centered_weight, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else 
(torch.float32 if residual_in_fp32 else None) + ) + y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", 
dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/pipelines/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..ec66dcdb9bcd70f79940e5a088ef4a984d6a3722 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/pipelines/image_processor.py @@ -0,0 +1,266 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist +from diffusers.configuration_utils import register_to_config + +class OmniGen2ImageProcessor(VaeImageProcessor): + """ + Image processor for PixArt image resize and crop. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. 
+ """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + resample: str = "lanczos", + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_grayscale: bool = False, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + resample=resample, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) + + self.max_pixels = max_pixels + self.max_side_length = max_side_length + + def get_new_height_width( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + ) -> Tuple[int, int]: + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor`. 
+ """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + if max_side_length is None: + max_side_length = self.max_side_length + + if max_pixels is None: + max_pixels = self.max_pixels + + ratio = 1.0 + if max_side_length is not None: + if height > width: + max_side_length_ratio = max_side_length / height + else: + max_side_length_ratio = max_side_length / width + + cur_pixels = height * width + max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0) # do not upscale input image + + new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor + return new_height, new_width + + def preprocess( + self, + image: PipelineImageInput, + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + resize_mode: str = "default", # "default", "fill", "crop" + crops_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> torch.Tensor: + """ + Preprocess the image input. + + Args: + image (`PipelineImageInput`): + The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of + supported formats. + height (`int`, *optional*): + The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default + height. + width (`int`, *optional*): + The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. 
+ resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within + the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will + resize the image to fit within the specified width and height, maintaining the aspect ratio, and then + center the image within the dimensions, filling empty with data from image. If `crop`, will resize the + image to fit within the specified width and height, maintaining the aspect ratio, and then center the + image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. + crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): + The crop coordinates for each image in the batch. If `None`, will not crop the image. + + Returns: + `torch.Tensor`: + The preprocessed image. + """ + supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) + + # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image + if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3: + if isinstance(image, torch.Tensor): + # if image is a pytorch tensor could have 2 possible shapes: + # 1. batch x height x width: we should insert the channel dimension at position 1 + # 2. channel x height x width: we should insert batch dimension at position 0, + # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 + # for simplicity, we insert a dimension of size 1 at position 1 for both cases + image = image.unsqueeze(1) + else: + # if it is a numpy array, it could have 2 possible shapes: + # 1. batch x height x width: insert channel dimension on last position + # 2. 
height x width x channel: insert batch dimension on first position + if image.shape[-1] == 1: + image = np.expand_dims(image, axis=0) + else: + image = np.expand_dims(image, axis=-1) + + if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", + FutureWarning, + ) + image = np.concatenate(image, axis=0) + if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", + FutureWarning, + ) + image = torch.cat(image, axis=0) + + if not is_valid_image_imagelist(image): + raise ValueError( + f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}" + ) + if not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + if crops_coords is not None: + image = [i.crop(crops_coords) for i in image] + if self.config.do_resize: + height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length) + image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image] + if self.config.do_convert_rgb: + image = [self.convert_to_rgb(i) for i in image] + elif self.config.do_convert_grayscale: + image = [self.convert_to_grayscale(i) for i in image] + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + + image = self.numpy_to_pt(image) + + height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length) + if self.config.do_resize: + image = self.resize(image, height, 
width) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + if self.config.do_convert_grayscale and image.ndim == 3: + image = image.unsqueeze(1) + + channel = image.shape[1] + # don't need any preprocess if the image is latents + if channel == self.config.vae_latent_channels: + return image + + height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length) + if self.config.do_resize: + image = self.resize(image, height, width) + + # expected range [0,1], normalize to [-1,1] + do_normalize = self.config.do_normalize + if do_normalize and image.min() < 0: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", + FutureWarning, + ) + do_normalize = False + if do_normalize: + image = self.normalize(image) + + if self.config.do_binarize: + image = self.binarize(image) + + return image \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py new file mode 100644 index 0000000000000000000000000000000000000000..a48548f16e62349df6c9ce133f67d594537f7324 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py @@ -0,0 +1,729 @@ +""" +OmniGen2 Diffusion Pipeline + +Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import math + +from PIL import Image +import numpy as np +import torch +import torch.nn.functional as F + +from transformers import Qwen2_5_VLForConditionalGeneration + +from diffusers.models.autoencoders import AutoencoderKL +from ...models.transformers import OmniGen2Transformer2DModel +from ...models.transformers.repo import OmniGen2RotaryPosEmbed +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + is_torch_xla_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from dataclasses import dataclass + +import PIL.Image + +from diffusers.utils import BaseOutput + +from ....src.pipelines.image_processor import OmniGen2ImageProcessor + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@dataclass +class FMPipelineOutput(BaseOutput): + """ + Output class for OmniGen2 pipeline. + + Args: + images (Union[List[PIL.Image.Image], np.ndarray]): + List of denoised PIL images of length `batch_size` or numpy array of shape + `(batch_size, height, width, num_channels)`. Contains the generated images. 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and hand back the resulting schedule.

    Args:
        scheduler:
            Object exposing `set_timesteps(...)` and, after that call, a `timesteps` attribute.
        num_inference_steps (`int`, *optional*):
            Step count passed positionally to `set_timesteps` when no custom `timesteps` are given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device=`.
        timesteps (`List[int]`, *optional*):
            Custom schedule. When given, it is forwarded as `timesteps=` and `num_inference_steps` is
            ignored; the returned step count is the length of the schedule the scheduler produced.
        **kwargs:
            Extra keyword arguments forwarded unchanged to `scheduler.set_timesteps`.

    Returns:
        A `(timesteps, num_inference_steps)` tuple: the scheduler's `timesteps` attribute and the number
        of inference steps.

    Raises:
        ValueError: If custom `timesteps` are given but `scheduler.set_timesteps` has no `timesteps`
            parameter.
    """
    # Default path: let the scheduler derive its own spacing from the step count.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: only valid if the scheduler's signature actually takes `timesteps`.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in supported_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
def prepare_latents(
    self,
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[torch.Generator],
    latents: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    Return the starting latents for sampling.

    Args:
        batch_size: Leading dimension of the latent tensor.
        num_channels_latents: Channel dimension of the latent tensor.
        height: Image height in pixels; divided by `self.vae_scale_factor` to get the latent height.
        width: Image width in pixels; divided by `self.vae_scale_factor` to get the latent width.
        dtype: dtype used when fresh noise is drawn.
        device: Device for the returned tensor.
        generator: Random generator forwarded to `randn_tensor`.
        latents: Optional caller-supplied latents. When given they are only moved to `device`
            (shape and dtype are left untouched) and no noise is sampled.

    Returns:
        torch.FloatTensor: The supplied latents on `device`, or new noise of shape
        `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)`.
    """
    scale = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, int(height) // scale, int(width) // scale)

    if latents is not None:
        return latents.to(device)
    return randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
def prepare_image(
    self,
    images: Union[List[PIL.Image.Image], PIL.Image.Image],
    batch_size: int,
    num_images_per_prompt: int,
    max_pixels: int,
    max_side_length: int,
    device: torch.device,
    dtype: torch.dtype,
) -> List[Optional[torch.FloatTensor]]:
    """
    Encode the reference images of every prompt into VAE latents.

    Args:
        images: Reference images. With `batch_size == 1` this is the reference image (or list of
            reference images) of the single prompt; otherwise it is one entry per prompt, each entry
            being a list of images, a single image, or `None`. `None` means "no reference images for
            any prompt".
        batch_size: Number of prompts.
        num_images_per_prompt: How many times each prompt's entry is repeated in the output.
        max_pixels: Forwarded to `self.image_processor.preprocess`.
        max_side_length: Forwarded to `self.image_processor.preprocess`.
        device: Device the preprocessed pixels are moved to before VAE encoding.
        dtype: Currently unused; `encode_vae` casts its input to the VAE's own dtype.

    Returns:
        A list of length `len(prompts) * num_images_per_prompt`, grouped per prompt
        (p0, p0, ..., p1, p1, ...). Each element is either `None` (prompt without reference images)
        or the list of that prompt's encoded reference latents with the batch dimension squeezed out.
    """
    if images is None:
        # Text-only generation: previously this crashed for batch_size > 1 because `None` was iterated.
        per_prompt_images = [None] * batch_size
    elif batch_size == 1:
        per_prompt_images = [images]
    else:
        per_prompt_images = images

    latents = []
    for refs in per_prompt_images:
        # A bare image (no `__len__`) is treated as a one-element reference list.
        if refs is not None and not hasattr(refs, "__len__"):
            refs = [refs]

        if refs is not None and len(refs) > 0:
            ref_latents = []
            for ref in refs:
                pixels = self.image_processor.preprocess(
                    ref, max_pixels=max_pixels, max_side_length=max_side_length
                )
                ref_latents.append(self.encode_vae(pixels.to(device=device)).squeeze(0))
        else:
            ref_latents = None

        # The same latent list object is shared by all images generated for this prompt.
        latents.extend([ref_latents] * num_images_per_prompt)

    return latents
+ max_sequence_length: Maximum sequence length for tokenization. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The prompt embeddings tensor + - The attention mask tensor + + Raises: + Warning: If the input text is truncated due to sequence length limitations. + """ + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + # text_inputs = self.processor.tokenizer( + # prompt, + # padding="max_length", + # max_length=max_sequence_length, + # truncation=True, + # return_tensors="pt", + # ) + text_inputs = self.processor.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + # untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + # if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + # removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + # logger.warning( + # "The following part of your input was truncated because Gemma can only handle sequences up to" + # f" {max_sequence_length} tokens: {removed_text}" + # ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.mllm( + text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-1] + + if self.mllm is not None: + dtype = self.mllm.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def _apply_chat_template(self, prompt: str): + prompt = [ + { + "role": "system", + "content": "You are a helpful assistant that generates high-quality images based on user instructions.", + }, + {"role": "user", "content": prompt}, + ] 
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 256,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Encode the prompt (and, for classifier-free guidance, the negative prompt) into hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to encode. Each prompt is wrapped with the chat template first. Ignored for
            encoding when `prompt_embeds` is given.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether negative embeddings should be produced.
        negative_prompt (`str` or `List[str]`, *optional*):
            Negative prompt(s). Defaults to `""` for every prompt. Only used when
            `do_classifier_free_guidance` is set and `negative_prompt_embeds` is not given.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Number of images generated per prompt; embeddings and masks are repeated that many times.
        device (`torch.device`, *optional*):
            Device for the embeddings; defaults to `self._execution_device`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed prompt embeddings of shape `(batch, seq_len, dim)`. Must be accompanied by
            `prompt_attention_mask`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative embeddings. When given they are returned as-is (not repeated).
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask matching `prompt_embeds`.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask matching `negative_prompt_embeds`.
        max_sequence_length (`int`, defaults to `256`):
            Maximum sequence length used when tokenizing.

    Returns:
        `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask)`.
        Repeated tensors are grouped per prompt (p0, p0, ..., p1, p1, ...), and every mask row lines up
        with the embedding row at the same index.

    Raises:
        ValueError: If neither `prompt` nor `prompt_embeds` is given, if `prompt_embeds` is given
            without `prompt_attention_mask`, or if `negative_prompt` does not match the batch size.
    """
    device = device or self._execution_device

    if prompt is None and prompt_embeds is None:
        raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")

    if prompt is not None:
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [self._apply_chat_template(_prompt) for _prompt in prompt]

    if prompt_embeds is None:
        prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
            prompt=prompt,
            device=device,
            max_sequence_length=max_sequence_length,
        )
    elif prompt_attention_mask is None:
        raise ValueError("`prompt_attention_mask` must be provided together with `prompt_embeds`.")

    batch_size, seq_len, _ = prompt_embeds.shape
    # Duplicate embeddings and mask for each generation per prompt (mps friendly repeat + view).
    # Both are repeated along the sequence axis and then re-viewed, so row `b * n + k` of the mask
    # belongs to the same prompt as row `b * n + k` of the embeddings. Tiling the mask along the
    # batch axis instead would interleave prompts and misalign masks for batch_size > 1.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
    prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)

    # Get negative embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt if negative_prompt is not None else ""

        # Normalize str to list, then apply the chat template. Both `prompt` and `negative_prompt`
        # are lists from here on, so only the batch size needs validating.
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]

        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
            prompt=negative_prompt,
            device=device,
            max_sequence_length=max_sequence_length,
        )

        batch_size, seq_len, _ = negative_prompt_embeds.shape
        # Same per-prompt grouping as for the positive branch.
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(
            batch_size * num_images_per_prompt, -1
        )

    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
image_guidance_scale: float = 1.0, + cfg_range: Tuple[float, float] = (0.0, 1.0), + attention_kwargs: Optional[Dict[str, Any]] = None, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + verbose: bool = False, + step_func=None, + ): + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._text_guidance_scale = text_guidance_scale + self._image_guidance_scale = image_guidance_scale + self._cfg_range = cfg_range + self._attention_kwargs = attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.text_guidance_scale > 1.0, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + ) + + dtype = self.vae.dtype + # 3. 
def processing(
    self,
    latents,
    ref_latents,
    prompt_embeds,
    freqs_cis,
    negative_prompt_embeds,
    prompt_attention_mask,
    negative_prompt_attention_mask,
    num_inference_steps,
    timesteps,
    device,
    dtype,
    verbose,
    step_func=None
):
    """
    Run the denoising loop and decode the final latents with the VAE.

    Args:
        latents: Starting latents; the product of their last two dimensions is passed to the
            scheduler as `num_tokens`.
        ref_latents: Reference-image latents handed to `self.predict` as `ref_image_hidden_states`.
        prompt_embeds / prompt_attention_mask: Conditioning for the positive prediction.
        negative_prompt_embeds / negative_prompt_attention_mask: Conditioning for the guidance
            predictions.
        freqs_cis: Rotary embedding table forwarded to `self.predict`.
        num_inference_steps: Requested number of steps (forwarded to `retrieve_timesteps`).
        timesteps: Optional custom schedule (forwarded to `retrieve_timesteps`).
        device: Device forwarded to `retrieve_timesteps`.
        dtype: dtype the latents are cast to after every scheduler step and before decoding.
        verbose: Accepted for interface compatibility; not read in this method.
        step_func: Optional callback invoked as `step_func(step_index, total_steps)` after each step.

    Returns:
        The first element of `self.vae.decode(latents, return_dict=False)`.
    """
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        num_tokens=latents.shape[-2] * latents.shape[-1],
    )
    total_steps = len(timesteps)
    num_warmup_steps = max(total_steps - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = total_steps

    def predict_negative(step_t, current_latents, refs):
        # Prediction conditioned on the negative prompt, with or without reference images.
        return self.predict(
            t=step_t,
            latents=current_latents,
            prompt_embeds=negative_prompt_embeds,
            freqs_cis=freqs_cis,
            prompt_attention_mask=negative_prompt_attention_mask,
            ref_image_hidden_states=refs,
        )

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            model_pred = self.predict(
                t=t,
                latents=latents,
                prompt_embeds=prompt_embeds,
                freqs_cis=freqs_cis,
                prompt_attention_mask=prompt_attention_mask,
                ref_image_hidden_states=ref_latents,
            )

            # Guidance is only active while the normalized step index lies inside `cfg_range`.
            if self.cfg_range[0] <= i / total_steps <= self.cfg_range[1]:
                text_scale = self.text_guidance_scale
                image_scale = self.image_guidance_scale
            else:
                text_scale = 1.0
                image_scale = 1.0

            if text_scale > 1.0 and image_scale > 1.0:
                # Dual guidance: negative prompt with references, then negative prompt without.
                # `image_scale > 1.0` already implies `image_scale != 1`, so the reference-free
                # prediction is always computed on this path.
                pred_ref = predict_negative(t, latents, ref_latents)
                pred_uncond = predict_negative(t, latents, None)
                model_pred = pred_uncond + image_scale * (pred_ref - pred_uncond) + \
                    text_scale * (model_pred - pred_ref)
            elif text_scale > 1.0:
                # Text-only guidance keeps the reference images in the negative branch.
                pred_uncond = predict_negative(t, latents, ref_latents)
                model_pred = pred_uncond + text_scale * (model_pred - pred_uncond)

            latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
            latents = latents.to(dtype=dtype)

            is_last_step = i == total_steps - 1
            past_warmup = (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            if is_last_step or past_warmup:
                progress_bar.update()

            if step_func is not None:
                step_func(i, self._num_timesteps)

    # Undo the latent normalization applied at encode time, then decode.
    latents = latents.to(dtype=dtype)
    vae_config = self.vae.config
    if vae_config.scaling_factor is not None:
        latents = latents / vae_config.scaling_factor
    if vae_config.shift_factor is not None:
        latents = latents + vae_config.shift_factor
    return self.vae.decode(latents, return_dict=False)[0]
b/ai-toolkit/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2_chat.py @@ -0,0 +1,830 @@ +""" +OmniGen2 Diffusion Pipeline + +Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import math + +from PIL import Image +import numpy as np +import torch +import torch.nn.functional as F + +from transformers import Qwen2_5_VLForConditionalGeneration + +from diffusers.models.autoencoders import AutoencoderKL +from ...models.transformers import OmniGen2Transformer2DModel +from ...models.transformers.repo import OmniGen2RotaryPosEmbed +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + is_torch_xla_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from dataclasses import dataclass + +import PIL.Image + +from diffusers.utils import BaseOutput + +from src.pipelines.image_processor import OmniGen2ImageProcessor + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@dataclass +class OmniGen2PipelineOutput(BaseOutput): + """ + Output class for OmniGen2 pipeline. 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and hand back the resulting schedule.

    Args:
        scheduler:
            Object exposing `set_timesteps(...)` and, after that call, a `timesteps` attribute.
        num_inference_steps (`int`, *optional*):
            Step count passed positionally to `set_timesteps` when no custom `timesteps` are given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device=`.
        timesteps (`List[int]`, *optional*):
            Custom schedule. When given, it is forwarded as `timesteps=` and `num_inference_steps` is
            ignored; the returned step count is the length of the schedule the scheduler produced.
        **kwargs:
            Extra keyword arguments forwarded unchanged to `scheduler.set_timesteps`.

    Returns:
        A `(timesteps, num_inference_steps)` tuple: the scheduler's `timesteps` attribute and the number
        of inference steps.

    Raises:
        ValueError: If custom `timesteps` are given but `scheduler.set_timesteps` has no `timesteps`
            parameter.
    """
    # Default path: let the scheduler derive its own spacing from the step count.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: only valid if the scheduler's signature actually takes `timesteps`.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in supported_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
+ """ + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + scheduler=scheduler, + mllm=mllm, + processor=processor + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True) + self.default_sample_size = 128 + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator], + latents: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Prepare the initial latents for the diffusion process. + + Args: + batch_size: The number of images to generate. + num_channels_latents: The number of channels in the latent space. + height: The height of the generated image. + width: The width of the generated image. + dtype: The data type of the latents. + device: The device to place the latents on. + generator: The random number generator to use. + latents: Optional pre-computed latents to use instead of random initialization. + + Returns: + torch.FloatTensor: The prepared latents tensor. + """ + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor: + """ + Encode an image into the VAE latent space. + + Args: + img: The input image tensor to encode. + + Returns: + torch.FloatTensor: The encoded latent representation. 
def prepare_image(
    self,
    images: "Union[List[PIL.Image.Image], PIL.Image.Image]",
    batch_size: int,
    num_images_per_prompt: int,
    max_pixels: int,
    max_side_length: int,
    device: torch.device,
    dtype: torch.dtype,
) -> List[Optional[torch.FloatTensor]]:
    """
    Encode the reference images of every prompt into the VAE latent space.

    Args:
        images: Reference images. With `batch_size == 1` this is the image (or list of
            images) for the single prompt; otherwise it is one entry per prompt, each
            entry being a list of images, a single image, or `None`. `None` as a whole
            means that no prompt has reference images.
        batch_size: The number of prompts.
        num_images_per_prompt: The number of images to generate for each prompt.
        max_pixels: Forwarded to the image processor's `preprocess`.
        max_side_length: Forwarded to the image processor's `preprocess`.
        device: The device the preprocessed images are moved to before encoding.
        dtype: Currently unused; `encode_vae` casts to the VAE's own dtype.

    Returns:
        A list with `batch_size * num_images_per_prompt` entries, grouped per prompt.
        Each entry is either `None` (the prompt has no reference images) or the list of
        encoded latents for that prompt, with the leading batch dimension squeezed out.
        The copies belonging to one prompt share the same list object.
    """
    if batch_size == 1:
        # A single prompt receives its images directly; wrap them so the loop below
        # always sees one entry per prompt.
        images = [images]
    elif images is None:
        # No reference images at all: previously this crashed on `enumerate(None)`.
        images = [None] * batch_size

    latents = []
    for prompt_images in images:
        # A bare image (no `__len__`) is treated as a one-element list instead of
        # failing on `len()`.
        if prompt_images is not None and not hasattr(prompt_images, "__len__"):
            prompt_images = [prompt_images]

        if prompt_images is not None and len(prompt_images) > 0:
            ref_latents = []
            for image in prompt_images:
                pixels = self.image_processor.preprocess(
                    image, max_pixels=max_pixels, max_side_length=max_side_length
                )
                ref_latents.append(self.encode_vae(pixels.to(device=device)).squeeze(0))
        else:
            ref_latents = None

        # Keep the copies of one prompt adjacent, matching the per-prompt duplication
        # of the prompt embeddings.
        latents.extend([ref_latents] * num_images_per_prompt)

    return latents
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    input_images: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 256,
    use_text_encoder_penultimate_layer_feats: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Encodes the prompt (and, for classifier-free guidance, the negative prompt) into hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt to be encoded. Ignored when `prompt_embeds` is given.
        input_images (*optional*):
            Images forwarded to the processor together with `prompt`.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether negative embeddings are produced. When `False`, the last two return values are `None`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt not to guide the image generation. Defaults to `""`. A string is used for every
            prompt in the batch; a list must have one entry per prompt. Ignored when
            `negative_prompt_embeds` is given.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Number of images that should be generated per prompt.
        device (`torch.device`, *optional*):
            Device to place the resulting embeddings on. Defaults to the pipeline's execution device.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated prompt embeddings; requires `prompt_attention_mask`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative embeddings; requires `negative_prompt_attention_mask`.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask belonging to `prompt_embeds`.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask belonging to `negative_prompt_embeds`.
        max_sequence_length (`int`, defaults to `256`):
            Currently unused by this method.
        use_text_encoder_penultimate_layer_feats (`bool`, defaults to `False`):
            Currently unused by this method.

    Returns:
        `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask)`.
        Every tensor has `batch_size * num_images_per_prompt` rows, with the copies of one prompt adjacent.

    Raises:
        ValueError: If neither `prompt` nor `prompt_embeds` is given, if embeddings are passed without
            their attention mask, or if `negative_prompt` does not match the prompt batch size.
        TypeError: If `prompt` is not a list after normalization (e.g. a tuple) while negative prompts
            are being built.
    """

    def _repeat_per_image(embeds, mask):
        # Duplicate both tensors so that the copies of each prompt are adjacent. The mask has to be
        # tiled along the sequence axis before the reshape; tiling it along the batch axis would
        # interleave prompts and pair embeddings with the wrong mask when batch_size > 1.
        bsz, seq_len, _ = embeds.shape
        embeds = embeds.repeat(1, num_images_per_prompt, 1)
        embeds = embeds.view(bsz * num_images_per_prompt, seq_len, -1)
        mask = mask.repeat(1, num_images_per_prompt)
        mask = mask.view(bsz * num_images_per_prompt, -1)
        return embeds, mask

    device = device or self._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt_embeds is None:
        if prompt is None:
            raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")
        prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds(
            prompt=prompt,
            input_images=input_images,
            device=device,
        )
    elif prompt_attention_mask is None:
        raise ValueError("`prompt_attention_mask` must be provided together with `prompt_embeds`.")

    batch_size = prompt_embeds.shape[0]
    prompt_embeds, prompt_attention_mask = _repeat_per_image(prompt_embeds, prompt_attention_mask)

    if not do_classifier_free_guidance:
        return prompt_embeds, prompt_attention_mask, None, None

    if negative_prompt_embeds is None:
        negative_prompt = negative_prompt if negative_prompt is not None else ""

        # Normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt]

        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds(
            prompt=negative_prompt,
            device=device,
        )
    elif negative_prompt_attention_mask is None:
        raise ValueError(
            "`negative_prompt_attention_mask` must be provided together with `negative_prompt_embeds`."
        )

    # Caller-supplied negative embeddings are honoured and duplicated the same way as the prompt's.
    negative_prompt_embeds, negative_prompt_attention_mask = _repeat_per_image(
        negative_prompt_embeds, negative_prompt_attention_mask
    )

    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
def generate_image(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    prompt_attention_mask: Optional[torch.LongTensor] = None,
    negative_prompt_attention_mask: Optional[torch.LongTensor] = None,
    use_text_encoder_penultimate_layer_feats: bool = False,
    max_sequence_length: Optional[int] = None,
    callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
    input_images: Optional[List[PIL.Image.Image]] = None,
    num_images_per_prompt: int = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    max_pixels: int = 1024 * 1024,
    max_input_image_side_length: int = 1024,
    align_res: bool = True,
    num_inference_steps: int = 28,
    text_guidance_scale: float = 4.0,
    image_guidance_scale: float = 1.0,
    cfg_range: Tuple[float, float] = (0.0, 1.0),
    attention_kwargs: Optional[Dict[str, Any]] = None,
    timesteps: List[int] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    verbose: bool = False,
    step_func=None,
):
    """
    Run the image-generation half of the pipeline: encode the prompt and the reference
    images, sample latents with `processing`, and post-process the decoded image.

    Args:
        prompt / prompt_embeds / prompt_attention_mask: Prompt text or pre-computed
            embeddings, forwarded to `encode_prompt`. If `prompt` is neither a string nor
            a list, the batch size is read from `prompt_embeds`.
        negative_prompt / negative_prompt_embeds / negative_prompt_attention_mask:
            Negative conditioning, forwarded to `encode_prompt`.
        use_text_encoder_penultimate_layer_feats, max_sequence_length: Forwarded to
            `encode_prompt`.
        callback_on_step_end_tensor_inputs: Not referenced in this method.
        input_images: Reference images, forwarded to `encode_prompt` and `prepare_image`.
        num_images_per_prompt: Number of images generated per prompt.
        height, width: Requested output size; each defaults to
            `default_sample_size * vae_scale_factor` when not given.
        max_pixels: Upper bound on `height * width` of the sampled image; also passed
            to `prepare_image`.
        max_input_image_side_length: Passed to `prepare_image` as `max_side_length`.
        align_res: When exactly one input image is given, take the output size from that
            image's encoded latent instead of `height`/`width`.
        num_inference_steps, timesteps: Forwarded to `processing`.
        text_guidance_scale: Stored on the pipeline; classifier-free guidance is requested
            from `encode_prompt` only when it is greater than 1.0.
        image_guidance_scale: Stored on the pipeline; reset to 1 when there are no input images.
        cfg_range: Stored on the pipeline and read by `processing`.
        attention_kwargs: Stored as `self._attention_kwargs`; not otherwise used here.
        generator, latents: Forwarded to `prepare_latents`.
        output_type: Forwarded to the image processor's `postprocess`.
        return_dict: Not referenced in this method.
        verbose, step_func: Forwarded to `processing`.

    Returns:
        The result of `self.image_processor.postprocess` for the generated image.
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Store the guidance settings; they are read back through the properties during sampling.
    self._text_guidance_scale = text_guidance_scale
    self._image_guidance_scale = image_guidance_scale
    self._cfg_range = cfg_range
    self._attention_kwargs = attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        input_images,
        self.text_guidance_scale > 1.0,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        max_sequence_length=max_sequence_length,
        use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats
    )

    dtype = self.vae.dtype
    # 4. Encode the reference images into latents
    ref_latents = self.prepare_image(
        images=input_images,
        batch_size=batch_size,
        num_images_per_prompt=num_images_per_prompt,
        max_pixels=max_pixels,
        max_side_length=max_input_image_side_length,
        device=device,
        dtype=dtype,
    )

    if input_images is None:
        input_images = []

    # 5. Resolve the output size. With a single reference image and `align_res`, the size comes
    # from that image's latent (last two dims scaled by `vae_scale_factor`).
    if len(input_images) == 1 and align_res:
        width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor
        ori_width, ori_height = width, height
    else:
        ori_width, ori_height = width, height

    # Shrink (never enlarge) so that height * width stays within `max_pixels`, then floor both
    # sides to a multiple of 16. The decoded image is resized back to the original size below.
    cur_pixels = height * width
    ratio = (max_pixels / cur_pixels) ** 0.5
    ratio = min(ratio, 1.0)

    height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16

    # Without reference images, image guidance is set to 1.
    if len(input_images) == 0:
        self._image_guidance_scale = 1

    # 6. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
        self.transformer.config.axes_dim_rope,
        self.transformer.config.axes_lens,
        theta=10000,
    )

    # 7. Run the sampling loop and decode.
    image = self.processing(
        latents=latents,
        ref_latents=ref_latents,
        prompt_embeds=prompt_embeds,
        freqs_cis=freqs_cis,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        num_inference_steps=num_inference_steps,
        timesteps=timesteps,
        device=device,
        dtype=dtype,
        verbose=verbose,
        step_func=step_func,
    )

    # 8. Resize to the size resolved in step 5 (before the max_pixels / multiple-of-16 adjustment).
    image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')

    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()
    return image
Tuple[float, float] = (0.0, 1.0), + attention_kwargs: Optional[Dict[str, Any]] = None, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + verbose: bool = False, + step_func=None, + ): + assert isinstance(prompt, str), "prompt must be a string since chat mode only support one prompt per turn" + + # input_images = self.preprocess_images(input_images, max_input_image_size) + prompt = self._apply_chat_template(prompt, input_images) + generated_text = self.generate_text(prompt, input_images)[0] + + images = None + if generated_text.startswith("<|img|>"): + #TODO: reuse the hidden state when generate text instead of re-generating + prompt = prompt + generated_text.split("<|img|>")[0] + images = self.generate_image( + prompt=prompt, + negative_prompt=negative_prompt, + use_text_encoder_penultimate_layer_feats=use_text_encoder_penultimate_layer_feats, + max_sequence_length=max_sequence_length, + input_images=input_images, + num_images_per_prompt=num_images_per_prompt, + height=height, + width=width, + max_pixels=max_pixels, + max_input_image_side_length=max_input_image_side_length, + align_res=align_res, + num_inference_steps=num_inference_steps, + text_guidance_scale=text_guidance_scale, + image_guidance_scale=image_guidance_scale, + cfg_range=cfg_range, + timesteps=timesteps, + generator=generator, + latents=latents, + return_dict=False, + verbose=verbose, + step_func=step_func, + ) + + generated_text = generated_text.replace("<|im_end|>", "") + if not return_dict: + return generated_text, images + else: + return OmniGen2PipelineOutput(text=generated_text, images=images) + + def processing( + self, + latents, + ref_latents, + prompt_embeds, + freqs_cis, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + num_inference_steps, + timesteps, + device, + dtype, + verbose, 
+ step_func=None + ): + batch_size = latents.shape[0] + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + num_tokens=latents.shape[-2] * latents.shape[-1] + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_pred = self.predict( + t=t, + latents=latents, + prompt_embeds=prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: + model_pred_ref = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + if image_guidance_scale != 1: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + else: + model_pred_uncond = torch.zeros_like(model_pred) + + model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \ + text_guidance_scale * (model_pred - model_pred_ref) + elif text_guidance_scale > 1.0: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + model_pred = model_pred_uncond + text_guidance_scale * (model_pred - 
def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    """Encode prompts longer than the tokenizer's `model_max_length`.

    The prompt is tokenized without truncation, split into windows of
    `model_max_length` tokens, each window is passed through the text encoder, and the
    per-window outputs are concatenated along the sequence dimension. The negative
    prompt, if given, is padded or truncated to the prompt's token length and encoded
    the same way, so both results have the same sequence length.

    :param pipeline: object exposing `tokenizer` and `text_encoder`
    :param prompt: prompt text (or list of texts) to encode
    :param negative_prompt: negative prompt text, or `None` to skip it
    :param device: device the token ids and attention masks are moved to
    :return: `(prompt_embeds, negative_embeds)`; `negative_embeds` is `None` when
        `negative_prompt` is `None`
    """
    max_length = pipeline.tokenizer.model_max_length
    # The attention mask is only passed to encoders whose config asks for it.
    use_attention_mask = bool(getattr(pipeline.text_encoder.config, "use_attention_mask", False))

    prompt_tokens = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest')
    input_ids = prompt_tokens.input_ids.to(device)
    # The mask lives on the tokenizer output, not on the id tensor.
    input_mask = prompt_tokens.attention_mask.to(device) if use_attention_mask else None
    shape_max_length = input_ids.shape[-1]

    negative_ids = None
    negative_mask = None
    if negative_prompt is not None:
        negative_tokens = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length",
                                             max_length=shape_max_length, return_tensors="pt")
        negative_ids = negative_tokens.input_ids.to(device)
        negative_mask = negative_tokens.attention_mask.to(device) if use_attention_mask else None

    def _encode_window(ids, mask, start):
        # Encode one `max_length`-sized window of token ids (with its mask slice, if any).
        stop = start + max_length
        window_mask = mask[:, start:stop] if mask is not None else None
        return pipeline.text_encoder(ids[:, start:stop], attention_mask=window_mask)[0]

    concat_embeds = []
    neg_embeds = []
    for start in range(0, shape_max_length, max_length):
        concat_embeds.append(_encode_window(input_ids, input_mask, start))
        if negative_ids is not None:
            neg_embeds.append(_encode_window(negative_ids, negative_mask, start))

    concat_embeds = torch.cat(concat_embeds, dim=1)
    neg_embeds = torch.cat(neg_embeds, dim=1) if negative_ids is not None else None

    return concat_embeds, neg_embeds
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` curve, t in [0, 1], into a beta schedule.

    `alpha_bar(t)` is the cumulative product of (1 - beta) up to time t. For each of the
    `num_diffusion_timesteps` equal sub-intervals [t1, t2], the beta is
    `1 - alpha_bar(t2) / alpha_bar(t1)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for
                     alpha_bar. Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas
        used by the scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":
        alpha_bar_fn = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    elif alpha_transform_type == "exp":
        alpha_bar_fn = lambda t: math.exp(t * -12.0)
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    steps = num_diffusion_timesteps
    schedule = [
        min(1 - alpha_bar_fn((step + 1) / steps) / alpha_bar_fn(step / steps), max_beta)
        for step in range(steps)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
+ + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. 
+ algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. 
Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. + use_lu_lambdas (`bool`, *optional*, defaults to `False`): + Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during + the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of + `lambda(t)`. + use_flow_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. + flow_shift (`float`, *optional*, defaults to 1.0): + The shift value for the timestep schedule for flow matching. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. 
Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: str = 'zero', + dynamic_time_shift: bool = True + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + # if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": + # raise ValueError( + # f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." 
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
    timesteps: Optional[List[int]] = None,
    num_tokens: Optional[int] = None,
):
    """
    Build the inference schedule and reset the multistep solver state.

    When `timesteps` is not given, `num_inference_steps` evenly spaced values in `[0, 1)` are generated and,
    if `self.config.dynamic_time_shift` is enabled and `num_tokens` is provided, warped by a token-count
    dependent shift. Sigmas are `1 - timesteps` with a trailing `0` appended.

    Args:
        num_inference_steps (`int`, *optional*):
            Number of steps to generate. Required when `timesteps` is `None`; ignored otherwise.
        device (`str` or `torch.device`, *optional*):
            Device the `timesteps` tensor is placed on. `sigmas` are always moved to CPU afterwards.
        timesteps (sequence, `np.ndarray` or `torch.Tensor`, *optional*):
            Custom schedule used as-is (no dynamic shift is applied to it).
        num_tokens (`int`, *optional*):
            Token count used for the dynamic time shift.

    Raises:
        ValueError: If neither `num_inference_steps` nor `timesteps` is provided.
    """
    if timesteps is None:
        if num_inference_steps is None:
            # Previously this fell through to `None + 1` and failed with an opaque TypeError.
            raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.")
        timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
        if self.config.dynamic_time_shift and num_tokens is not None:
            # m = sqrt(num_tokens) / 40, i.e. m == 1 (identity mapping) at 1600 tokens.
            m = np.sqrt(num_tokens) / 40
            timesteps = timesteps / (m - m * timesteps + timesteps)

    # `torch.as_tensor` accepts lists, NumPy arrays and tensors alike; `torch.from_numpy` rejected
    # the `List[int]` that the signature advertises.
    timesteps = torch.as_tensor(timesteps).to(dtype=torch.float32, device=device)
    sigmas = torch.cat([1 - timesteps, torch.zeros(1, device=timesteps.device)])

    self.timesteps = timesteps
    self.num_inference_steps = len(timesteps)

    # Forget any model outputs cached by a previous sampling run.
    self.model_outputs = [None] * self.config.solver_order
    self.lower_order_nums = 0

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self._begin_index = None
    self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 - sigma + sigma_t = sigma + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Lu et al. 
(2022).""" + + lambda_min: float = in_lambdas[-1].item() + lambda_max: float = in_lambdas[0].item() + + rho = 1.0 # 1.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = lambda_min ** (1 / rho) + max_inv_rho = lambda_max ** (1 / rho) + lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return lambdas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. 
def convert_model_output(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    Convert the model output to the quantity the configured solver integrates.

    For `dpmsolver++` / `sde-dpmsolver++` the result is a prediction of the clean sample (`x0`); for
    `dpmsolver` / `sde-dpmsolver` it is a prediction of the noise (`epsilon`). The conversion reads
    `self.sigmas[self.step_index]`, so the step index must already be initialised.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        *args:
            Deprecated positional form `(timestep, sample)`.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        **kwargs:
            Only the deprecated `timestep` keyword is read.

    Returns:
        `torch.Tensor`:
            The converted model output.

    Raises:
        ValueError: If `sample` is missing, or `prediction_type` is not supported by the configured
            algorithm type.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError("missing `sample` as a required keyward argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # `variance_type` is not one of the arguments this scheduler's `__init__` registers, so it may be
    # absent from the config; a plain attribute access made every `epsilon` conversion fail.
    variance_type = getattr(self.config, "variance_type", None)
    predicts_variance = variance_type in ["learned", "learned_range"]

    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if predicts_variance:
                model_output = model_output[:, :3]
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            sigma_t = self.sigmas[self.step_index]
            x0_pred = sample + sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred

    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if predicts_variance:
                epsilon = model_output[:, :3]
            else:
                epsilon = model_output
        elif self.config.prediction_type == "sample":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = (sample - alpha_t * model_output) / sigma_t
        elif self.config.prediction_type == "v_prediction":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = alpha_t * model_output + sigma_t * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            # Threshold in x0-space, then map the clipped x0 back to an epsilon prediction.
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * epsilon) / alpha_t
            x0_pred = self._threshold_sample(x0_pred)
            epsilon = (sample - alpha_t * x0_pred) / sigma_t

        return epsilon
+ """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = 
None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if 
self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * 
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the third-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            Converted model outputs; the last three entries are used, with `[-1]` belonging to the
            current step and `[-2]`, `[-3]` to the two steps before it.
        *args:
            Deprecated positional form `(timestep_list, prev_timestep, sample)`.
        sample (`torch.Tensor`):
            A current instance of a sample created by diffusion process.
        noise (`torch.Tensor`, *optional*):
            Noise for the stochastic update; required when `algorithm_type` is `sde-dpmsolver++`.
        **kwargs:
            Only the deprecated `timestep_list` / `prev_timestep` keywords are read.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        ValueError: If `sample` is not provided.
        NotImplementedError: If the configured `algorithm_type` has no third-order update here
            (e.g. `sde-dpmsolver`).
    """

    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError(" missing`sample` as a required keyward argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # Target sigma (t) and the sigmas of the current and two preceding steps (s0, s1, s2).
    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

    # lambda = log(alpha) - log(sigma) at each of the four points.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    # Step sizes in lambda, and the ratios of the two previous steps to the current one.
    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    # Finite-difference terms built from the three cached model outputs.
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * D0
            + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )
    elif self.config.algorithm_type == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (alpha_t / alpha_s0) * sample
            - (sigma_t * (torch.exp(h) - 1.0)) * D0
            - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
        )
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s0 * torch.exp(-h)) * sample
            + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
            + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
            + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    else:
        # Without this branch `x_t` was never assigned and the return below failed with an
        # UnboundLocalError that said nothing about the actual cause.
        raise NotImplementedError(
            f"algorithm_type {self.config.algorithm_type} does not implement a third-order update"
        )
    return x_t
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance the sample by one solver step using the multistep DPMSolver.

    The raw model output is converted with `convert_model_output`, pushed onto the rolling
    `self.model_outputs` history, and the first-, second- or third-order update is selected from the
    configured solver order, the number of outputs cached so far and the position in the schedule.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int` or `torch.Tensor`):
            The current timestep; only used to initialise the step index on the first call.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator, used when noise has to be drawn for the SDE variants.
        variance_noise (`torch.Tensor`, *optional*):
            Noise to use for the SDE variants instead of drawing it with `generator`.
        return_dict (`bool`):
            Whether to return a `SchedulerOutput` or a plain tuple.

    Returns:
        `SchedulerOutput` when `return_dict` is `True`, otherwise a one-element tuple holding the sample
        tensor for the next step.

    Raises:
        ValueError: If `set_timesteps` has not been called yet.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    total_steps = len(self.timesteps)

    # Improve numerical stability for small number of steps
    on_last_step = self.step_index == total_steps - 1
    lower_order_final = on_last_step and (
        self.config.euler_at_final
        or (self.config.lower_order_final and total_steps < 15)
        or self.config.final_sigmas_type == "zero"
    )
    on_second_to_last_step = self.step_index == total_steps - 2
    lower_order_second = on_second_to_last_step and self.config.lower_order_final and total_steps < 15

    # Convert, then shift the history left by one slot and store the newest output last.
    model_output = self.convert_model_output(model_output, sample=sample)
    for slot in range(self.config.solver_order - 1):
        self.model_outputs[slot] = self.model_outputs[slot + 1]
    self.model_outputs[-1] = model_output

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    # Only the SDE variants consume noise.
    uses_noise = self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]
    if not uses_noise:
        noise = None
    elif variance_noise is None:
        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
        )
    else:
        noise = variance_noise.to(device=model_output.device, dtype=torch.float32)

    # Pick the highest order that the configuration and the cached history allow.
    solver_order = self.config.solver_order
    if solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # Cast sample back to expected dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    self._step_index += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Sample produced by one scheduler step. It should be used as the next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler for flow matching on a normalized time axis.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    The schedule is increasing: `timesteps` starts at 0 and `step` integrates towards a terminal time of 1, which is
    kept as an extra trailing entry in the private `_timesteps` tensor.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            Number of evenly spaced timesteps in the default schedule built by the constructor.
        dynamic_time_shift (`bool`, defaults to `True`):
            Whether `set_timesteps` warps the inference schedule as a function of `num_tokens` (only applied when
            `num_tokens` is passed).
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        dynamic_time_shift: bool = True
    ):
        timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]

        self.timesteps = timesteps
        # Schedule plus the terminal time 1.0. Built here as well as in `set_timesteps` so that
        # `index_for_timestep` / `step` do not fail with AttributeError on a freshly constructed scheduler.
        self._timesteps = torch.cat([timesteps, torch.ones(1, dtype=timesteps.dtype)])
        self.num_inference_steps = None

        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps` (defaults to the current schedule including
        the terminal time).

        Raises:
            IndexError: if `timestep` is not one of the schedule values (matching is by exact equality).
        """
        if schedule_timesteps is None:
            schedule_timesteps = self._timesteps

        indices = (schedule_timesteps == timestep).nonzero()
        if len(indices) == 0:
            raise IndexError(
                f"timestep {timestep} is not part of the current schedule; pass one of `scheduler.timesteps`."
            )

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    #     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[float]] = None,
        num_tokens: Optional[int] = None
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. Required
                unless `timesteps` is given.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[float]`, *optional*):
                Custom schedule. It is used as-is (no dynamic time shift is applied to it).
            num_tokens (`int`, *optional*):
                When given and `dynamic_time_shift` is enabled, the evenly spaced schedule is warped with
                `m = sqrt(num_tokens) / 40`.

        Raises:
            ValueError: if neither `num_inference_steps` nor `timesteps` is provided.
        """

        if timesteps is None:
            if num_inference_steps is None:
                raise ValueError("Either `num_inference_steps` or `timesteps` must be provided.")
            timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
            if self.config.dynamic_time_shift and num_tokens is not None:
                m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
                timesteps = timesteps / (m - m * timesteps + timesteps)
        else:
            # A custom schedule defines the number of steps.
            num_inference_steps = len(timesteps)

        self.num_inference_steps = num_inference_steps

        # `as_tensor` accepts lists, numpy arrays and tensors alike (`from_numpy` only accepts arrays).
        timesteps = torch.as_tensor(timesteps, dtype=torch.float32, device=device)
        _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])

        self.timesteps = timesteps
        self._timesteps = _timesteps
        self._step_index = None
        self._begin_index = None

    def _init_step_index(self, timestep):
        # Without an explicit begin index, locate the step by looking the timestep up in the schedule.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one Euler step: `sample + (t_next - t) * model_output`.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current timestep; must be one of `scheduler.timesteps` (not an integer index).
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                Accepted for interface compatibility; this step is deterministic and does not use it.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a tuple
                is returned where the first element is the sample tensor.

        Raises:
            ValueError: if `timestep` is an integer or an int32/int64 tensor.
        """

        # Checked by dtype so the guard also fires for tensors that are not on the CPU.
        if isinstance(timestep, int) or (
            isinstance(timestep, torch.Tensor) and timestep.dtype in (torch.int32, torch.int64)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)
        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        t = self._timesteps[self.step_index]
        t_next = self._timesteps[self.step_index + 1]

        prev_sample = sample + (t_next - t) * model_output

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
def resize_image(image, max_pixels, img_scale_num):
    """Downscale `image` to at most `max_pixels` and snap both sides to multiples of `img_scale_num`.

    Args:
        image: object exposing `.size` as `(width, height)` and a PIL-style `.resize()`.
        max_pixels: upper bound on `width * height` of the result (before snapping).
        img_scale_num: both output sides are rounded down to a multiple of this value.

    Returns:
        The result of `image.resize(...)` with bicubic resampling.

    The image is never enlarged by the pixel budget, but each side is kept at a minimum of
    `img_scale_num` so a very thin image cannot be snapped down to a zero-sized dimension.
    """
    width, height = image.size
    cur_pixels = height * width
    ratio = (max_pixels / cur_pixels) ** 0.5
    ratio = min(ratio, 1.0)  # do not upscale input image

    def _snap(length):
        # round down to the grid, but never below one grid cell
        return max(int(length * ratio) // img_scale_num * img_scale_num, img_scale_num)

    new_height, new_width = _snap(height), _snap(width)

    image = image.resize((new_width, new_height), resample=Image.BICUBIC)
    return image


def create_collage(images: List[torch.Tensor]) -> "Image.Image":
    """Create a horizontal collage from a list of images.

    Each tensor is mapped with `x * 0.5 + 0.5` and pasted left to right onto a zero-filled
    3-channel canvas whose height is the tallest image; the last two dimensions of each
    tensor are used as (height, width).

    Raises:
        ValueError: if `images` is empty.
    """
    if len(images) == 0:
        raise ValueError("create_collage() needs at least one image")

    max_height = max(img.shape[-2] for img in images)
    total_width = sum(img.shape[-1] for img in images)
    canvas = torch.zeros((3, max_height, total_width), device=images[0].device)

    current_x = 0
    for img in images:
        h, w = img.shape[-2:]
        # clamp so out-of-range values cannot overflow in the float -> uint8 conversion
        canvas[:, :h, current_x:current_x + w] = (img * 0.5 + 0.5).clamp(0.0, 1.0)
        current_x += w

    return to_pil_image(canvas)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" + +import importlib.util +import sys + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + +_triton_available, _triton_version = _is_package_available("triton") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") + +def is_triton_available(): + return _triton_available + +def is_flash_attn_available(): + return _flash_attn_available \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/README.md b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/README.md new file mode 100644 index 0000000000000000000000000000000000000000..19f33b1e6b1dbecbf20b2e6d1cb51e7a8a0af210 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/README.md @@ -0,0 +1,53 @@ +# PRXPixel (Photoroom PRX-7B, pixel-space text-to-image) + +ai-toolkit integration for [`Photoroom/prxpixel-t2i`](https://huggingface.co/Photoroom/prxpixel-t2i), +a ~7B pixel-space diffusion transformer. 
+ +It is implemented **from scratch** so ai-toolkit does not depend on the (still +unmerged) diffusers PR [huggingface/diffusers#13928](https://github.com/huggingface/diffusers/pull/13928): +the transformer is vendored in [src/transformer_prx.py](src/transformer_prx.py) +and a minimal preview sampler lives in [src/pipeline.py](src/pipeline.py). + +## What makes this model unusual + +PRXPixel differs from a typical latent flow-matching model in three ways, each +handled in [prx_pixel_t2i.py](prx_pixel_t2i.py): + +| Property | What it means | How it's handled | +|---|---|---| +| **Pixel space** | No VAE — the transformer denoises raw RGB (`in_channels=3`, `patch_size=16`) | A `FakeVAE` (identity, scaling 1) so encode/decode are no-ops; "latents" are the image in `[-1, 1]` | +| **x-prediction** | The model predicts the clean image `x0`, not the flow velocity | `get_noise_prediction` returns `x0`; `get_loss_target` is the clean latents. The `x0 → velocity` conversion only happens at sampling time | +| **noise_scale = 2.0** | Trains/samples from `randn * 2.0`, not unit noise | `get_latent_noise_from_latents` scales the training noise; the pipeline scales the starting noise | + +Text is encoded by the Qwen3-VL text tower (`Qwen3VLTextModel`, hidden size +2048 → the transformer's `context_in_dim`), padded to 256 tokens. + +The x-prediction objective follows *"Back to Basics: Let Denoising Generative +Models Denoise"* (https://arxiv.org/abs/2511.13720). + +## Architecture (released checkpoint) + +`depth=24`, `hidden_size=3584`, `num_heads=28`, `mlp_ratio=3.5`, +`in_channels=3`, `patch_size=16`, `context_in_dim=2048`, `bottleneck_size=768`, +`axes_dim=[64, 64]`, `resolution_embeds=True`, flow-matching scheduler with +`shift=3.0`. 
+ +## Train it + +```yaml +model: + arch: "prx_pixel" + name_or_path: "/path/to/prxpixel-t2i" # diffusers folder: transformer/, + # text_encoder/, tokenizer/, scheduler/ + quantize: true # optional: qfloat8 the transformer + quantize_te: true # optional: qfloat8 the Qwen3-VL text encoder +train: + gradient_checkpointing: true +sample: + guidance_scale: 5.0 + sample_steps: 28 +``` + +Datasets bucket to multiples of 16px (`vae_scale_factor * patch_size`). +See [../example_model/README.md](../example_model/README.md) for the generic +lifecycle, registration and LoRA conventions. diff --git a/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23cf44774141e9abb490835a90173000967a2f48 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/__init__.py @@ -0,0 +1,3 @@ +from .prx_pixel_t2i import PRXPixelT2IModel + +__all__ = ["PRXPixelT2IModel"] diff --git a/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/prx_pixel_t2i.py b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/prx_pixel_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..fb44ded38c064ab0dc0a676a4518ea939d99a534 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/prx_pixel_t2i.py @@ -0,0 +1,345 @@ +"""PRXPixelT2IModel -- the Photoroom pixel-space PRX-7B text-to-image model +(https://huggingface.co/Photoroom/prxpixel-t2i) wired into ai-toolkit. + +This is implemented from scratch (the diffusers support is an unmerged PR, +https://github.com/huggingface/diffusers/pull/13928) so ai-toolkit does not +depend on it: the transformer architecture is vendored in ``src/transformer_prx.py`` +and a minimal sampler lives in ``src/pipeline.py``. 
+ +What makes PRXPixel unusual (and what each override below is doing about it): + + - **Pixel space, no VAE.** The transformer denoises raw RGB directly + (``in_channels=3``, ``patch_size=16``). We use a ``FakeVAE`` (identity, + scaling_factor=1) so BaseModel's encode_images/decode_latents become no-ops + and the "latents" everywhere in the toolkit are just the image in [-1, 1]. + Same trick as ``../chroma/chroma_radiance_model.py`` and + ``extensions/z_image_pixel``. + + - **x-prediction.** The model predicts the CLEAN image x0, not the + flow-matching velocity. ai-toolkit's MSE compares ``get_noise_prediction`` + against ``get_loss_target``; we set BOTH to the x0 space (prediction = the + model's x0 output, target = the clean latents), which is PRXPixel's native + training objective ("Back to Basics: Let Denoising Generative Models + Denoise", https://arxiv.org/abs/2511.13720). The x0->velocity conversion + only happens at sampling time, inside the pipeline. + + - **noise_scale = 2.0.** PRXPixel trains with a non-unit initial-noise std, + so the noise mixed into the latents (training) and the starting noise + (sampling) are ``randn * noise_scale``. We override + ``get_latent_noise_from_latents`` for the training side; the pipeline + handles the sampling side. + + - **Qwen3-VL text tower.** Prompts are encoded by ``Qwen3VLTextModel`` + (hidden size 2048 -> the transformer's ``context_in_dim``). We keep the + per-token ``last_hidden_state`` plus an attention mask. + +See ../example_model/README.md for the generic lifecycle/registration guide. 
# Flow-matching scheduler config, matching the released model's
# scheduler/scheduler_config.json (shift 3.0, 1000 train timesteps).
scheduler_config = {
    "num_train_timesteps": 1000,
    "use_dynamic_shifting": False,
    "shift": 3.0,
}

# Number of text tokens PRXPixel was trained with (the Qwen tokenizer's own
# model_max_length is far larger). Matches PRX_PIXEL_DEFAULT_MAX_TOKENS.
PROMPT_MAX_TOKENS = 256
# Initial-noise std PRXPixel trains/samples with.
NOISE_SCALE = 2.0


class PRXPixelT2IModel(BaseModel):
    """ai-toolkit model wrapper for the pixel-space, x-prediction PRXPixel transformer."""

    # ``model.arch: "prx_pixel"`` in the training config YAML selects this class.
    arch = "prx_pixel"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        # Matched against type(module).__name__ to place LoRA layers.
        self.target_lora_modules = ["PRXTransformer2DModel"]

        # used by our overrides below
        self.patch_size = 16
        self.vae_scale_factor = 1  # pixel space: no VAE downsampling
        self.max_text_length = PROMPT_MAX_TOKENS
        self.noise_scale = NOISE_SCALE

    @staticmethod
    def get_train_scheduler():
        """Build a new flow-matching scheduler from the module-level ``scheduler_config``."""
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        # pixels must be divisible by patch_size (no VAE downsample): 1 * 16 = 16
        return self.vae_scale_factor * self.patch_size

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load_model(self):
        """Load transformer, text encoder, tokenizer and the identity VAE from ``name_or_path``."""
        dtype = self.torch_dtype
        self.print_and_status_update("Loading PRXPixel model")
        # Expected diffusers-style layout under name_or_path:
        #   transformer/  (PRXTransformer2DModel)   text_encoder/  (Qwen3VLTextModel)
        #   tokenizer/    scheduler/
        model_path = self.model_config.name_or_path

        # --- transformer (vendored PRXTransformer2DModel) ---
        self.print_and_status_update("Loading transformer")
        # from_pretrained reads config.json (bottleneck_size, resolution_embeds,
        # in_channels=3, ...) and the safetensors in one shot.
        transformer = PRXTransformer2DModel.from_pretrained(
            model_path, subfolder="transformer", torch_dtype=dtype
        )
        transformer.to(dtype=dtype)
        flush()

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing transformer")
            quantize_model(self, transformer)
            flush()

        if self.model_config.low_vram:
            transformer.to("cpu")
        else:
            transformer.to(self.device_torch, dtype=dtype)
        flush()

        # --- text encoder + tokenizer (Qwen3-VL text tower) ---
        self.print_and_status_update("Loading text encoder")
        tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
        text_encoder = Qwen3VLTextModel.from_pretrained(
            model_path, subfolder="text_encoder", torch_dtype=dtype
        )
        text_encoder.to(self.te_device_torch)
        text_encoder.eval()
        text_encoder.requires_grad_(False)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing text encoder")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
            freeze(text_encoder)
            flush()

        # --- "VAE": identity, since PRXPixel is pixel space ---
        self.print_and_status_update("Preparing pixel-space VAE (identity)")
        vae = FakeVAE(scaling_factor=1.0)
        vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)

        # --- scheduler + store everything ---
        self.noise_scheduler = PRXPixelT2IModel.get_train_scheduler()
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.model = transformer  # aliased as self.transformer / self.unet
        self.pipeline = PRXPixelPipeline(self)
        self.print_and_status_update("Model Loaded")

    # ------------------------------------------------------------------
    # Sampling (training previews)
    # ------------------------------------------------------------------
    def get_generation_pipeline(self):
        return PRXPixelPipeline(self)

    def generate_single_image(
        self,
        pipeline: PRXPixelPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Render one preview image with ``pipeline`` and return the first result.

        ``gen_config.width`` / ``gen_config.height`` are snapped (in place) down to the
        bucket divisibility, but never below one patch.
        """
        if self.model.device == torch.device("cpu"):
            self.model.to(self.device_torch)

        # snap requested size to the model's divisibility (16px); a request smaller
        # than one patch would otherwise round down to 0
        sc = self.get_bucket_divisibility()
        gen_config.width = max(sc, int(gen_config.width // sc * sc))
        gen_config.height = max(sc, int(gen_config.height // sc * sc))

        img = pipeline(
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
        )[0]
        return img

    # ------------------------------------------------------------------
    # Training hooks
    # ------------------------------------------------------------------
    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0..1000 scale, 1000 = pure noise
        text_embeddings: PromptEmbeds,
        **kwargs,
    ):
        """Forward pass. Returns the model's predicted CLEAN image x0.

        PRXPixel is an x-prediction model, so the raw transformer output is the
        prediction we compare against ``get_loss_target`` (the clean latents).
        No velocity conversion happens here -- that is a sampling-time concern.
        """
        if self.model.device == torch.device("cpu"):
            self.model.to(self.device_torch)

        # toolkit timestep (0..1000) -> PRX flow time in [0, 1].
        t01 = timestep.to(self.device_torch, dtype=torch.float32) / 1000.0

        feats = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
        mask = getattr(text_embeddings, "attention_mask", None)
        if mask is not None:
            mask = mask.to(self.device_torch)

        x0_pred = self.model(
            hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
            timestep=t01,
            encoder_hidden_states=feats,
            attention_mask=mask,
            return_dict=False,
        )[0]
        return x0_pred

    def get_prompt_embeds(self, prompt) -> PromptEmbeds:
        """Encode prompt text with the Qwen3-VL text tower.

        Returns a PromptEmbeds whose ``text_embeds`` is (B, L, 2048) and whose
        ``attention_mask`` is (B, L). Prompts are padded to a fixed length
        (PROMPT_MAX_TOKENS) so cached embeds are concatenatable for CFG.
        """
        if isinstance(prompt, str):
            prompt = [prompt]

        if self.text_encoder.device == torch.device("cpu"):
            self.text_encoder.to(self.device_torch)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.max_text_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        attention_mask = text_inputs.attention_mask.to(self.text_encoder.device)

        with torch.no_grad():
            # Only the final layer is consumed, so intermediate hidden states are
            # not requested (they would be collected and then thrown away).
            output = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            embeds = output["last_hidden_state"].to(self.torch_dtype)

        pe = PromptEmbeds(embeds)
        # keep the mask as bool/long; it must not be dtype-cast with the embeds
        pe.attention_mask = attention_mask.bool()
        return pe

    def get_loss_target(self, *args, **kwargs):
        """x-prediction target: the clean image (x0).

        PRXPixel predicts x0 directly, so the MSE target is simply the clean
        latents (the pixel image in [-1, 1]), not the flow velocity.

        Raises:
            ValueError: if the ``batch`` keyword argument is missing.
        """
        batch = kwargs.get("batch")
        if batch is None:
            raise ValueError("get_loss_target() requires the `batch` keyword argument")
        return batch.latents.detach()

    def get_latent_noise_from_latents(self, latents: torch.Tensor, noise_offset=0.0):
        """Noise for the forward flow, scaled by the model's noise_scale.

        PRXPixel trains with a non-unit initial-noise std: the noise mixed into
        the latents is ``randn * noise_scale``. The scheduler's add_noise then
        forms ``x_t = (1 - t) * clean + t * noise``.
        """
        noise = torch.randn_like(latents) * self.noise_scale
        if noise_offset is not None and noise_offset != 0.0:
            noise = apply_noise_offset(noise, noise_offset)
        return noise

    def condition_noisy_latents(self, latents: torch.Tensor, batch):
        # plain text-to-image: nothing to inject
        return latents

    # ------------------------------------------------------------------
    # Saving / bookkeeping
    # ------------------------------------------------------------------
    def get_model_has_grad(self):
        return False

    def get_te_has_grad(self):
        return False

    def save_model(self, output_path, meta, save_dtype):
        """Full fine-tune save: write the transformer back in diffusers layout."""
        transformer: PRXTransformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )
        with open(os.path.join(output_path, "aitk_meta.yaml"), "w") as f:
            yaml.dump(meta, f)

    def get_base_model_version(self):
        return "prx_pixel"

    def get_transformer_block_names(self) -> Optional[List[str]]:
        return ["blocks"]

    def convert_lora_weights_before_save(self, state_dict):
        # on-disk LoRA keys use the "diffusion_model." prefix
        return {
            k.replace("transformer.", "diffusion_model."): v
            for k, v in state_dict.items()
        }

    def convert_lora_weights_before_load(self, state_dict):
        # inverse of convert_lora_weights_before_save
        return {
            k.replace("diffusion_model.", "transformer."): v
            for k, v in state_dict.items()
        }
b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa894520a0ac0647a50e31623325336699a86325 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/src/__init__.py @@ -0,0 +1,4 @@ +# Everything diffusers does NOT ship (until PR #13928 lands) lives in src/: +# the vendored PRX transformer architecture and a minimal pixel-space sampler. +from .transformer_prx import PRXTransformer2DModel +from .pipeline import PRXPixelPipeline diff --git a/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/src/pipeline.py b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/src/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a940468cc0aedd33b0d1630829bcbbc5e8173104 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/prx_pixel_t2i/src/pipeline.py @@ -0,0 +1,154 @@ +"""Minimal preview sampler for the pixel-space PRX (PRXPixel) model. + +ai-toolkit only uses this pipeline to render preview/sample images during +training (BaseModel.generate_images -> PRXPixelT2IModel.generate_single_image). +It does not need to be a diffusers DiffusionPipeline, and ai-toolkit always +encodes the prompts itself, so the pipeline only ever receives already-encoded +``PromptEmbeds`` (text features + attention mask), never raw text. + +PRXPixel specifics this sampler bakes in (see ../prx_pixel_t2i.py for the why): + - **pixel space**: there is no VAE. The model's "latents" are the RGB image + itself in [-1, 1]; the final "decode" is just a clamp + uint8 cast. + - **x-prediction**: the transformer predicts the clean image x0, not the + flow-matching velocity. Each step converts it to velocity + ``v = (x_t - x0) / t`` (t clamped for stability) before the scheduler step, + exactly like the diffusers PRXPixelPipeline. 
# Minimum normalized timestep used when converting an x0 prediction to a flow
# velocity ``v = (x_t - x0) / t``. Mirrors the 0.05 clamp in the diffusers
# PRXPixelPipeline; without it the division blows up as t -> 0.
X_PRED_T_MIN = 0.05


class PRXPixelPipeline:
    """Lightweight pixel-space flow-matching sampler used for training previews.

    The pipeline is a thin wrapper around a model object. From that object it
    reads ``device_torch``, ``torch_dtype``, ``transformer``, ``noise_scale``,
    ``get_train_scheduler()`` and ``decode_latents()``. It only consumes
    already-encoded prompt embeddings, never raw text.
    """

    def __init__(self, model):
        # ``model`` is the PRXPixelT2IModel (a BaseModel subclass), giving us
        # access to model.transformer, model.decode_latents, device/dtype, the
        # scheduler factory and the noise scale.
        self.model = model

    @property
    def device(self):
        """Device of the wrapped model (``model.device_torch``)."""
        return self.model.device_torch

    def to(self, *args, **kwargs):
        """No-op that returns ``self``.

        BaseModel.generate_images may call pipeline.to(device); devices are
        managed through the model itself, so all arguments are ignored.
        """
        return self

    def set_progress_bar_config(self, **kwargs):
        """Accept and ignore progress-bar options.

        Called by the sampler harness (inside a try/except, so optional).
        """
        pass

    def _embeds_and_mask(self, embeds, device, dtype):
        """Pull ``(features, attention_mask)`` out of a PromptEmbeds.

        Args:
            embeds: object exposing ``text_embeds`` -- expected to be
                ``(B, L, D)`` -- and optionally ``attention_mask``.
            device: target device for both tensors.
            dtype: target dtype for the text features only.

        Returns:
            ``(feats, mask)``. ``feats`` is ``text_embeds`` moved to
            ``device``/``dtype``. ``mask`` is ``None`` when the embeds carry no
            ``attention_mask``; otherwise it is only moved to ``device`` and
            its dtype is left untouched (it is never cast to the model dtype).
        """
        feats = embeds.text_embeds.to(device, dtype=dtype)
        mask = getattr(embeds, "attention_mask", None)
        if mask is not None:
            mask = mask.to(device)
        return feats, mask

    @torch.no_grad()
    def __call__(
        self,
        conditional_embeds,  # PromptEmbeds: .text_embeds (B,L,D) + .attention_mask
        unconditional_embeds,  # PromptEmbeds or None (negative prompt)
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 28,
        guidance_scale: float = 5.0,
        latents: Optional[torch.Tensor] = None,  # pre-made noise, usually None
        generator: Optional[
            torch.Generator
        ] = None,  # seeded RNG for reproducible samples
        **kwargs,
    ) -> List[Image.Image]:
        """Sample images from already-encoded prompts.

        Args:
            conditional_embeds: positive prompt embeddings.
            unconditional_embeds: negative prompt embeddings, or ``None`` to
                disable classifier-free guidance.
            height: output height in pixels (pixel space, no VAE downsample).
            width: output width in pixels.
            num_inference_steps: number of scheduler steps.
            guidance_scale: CFG scale; ``1.0`` disables guidance.
            latents: optional starting tensor. When given it is used as-is
                (only moved to device / float32); ``noise_scale`` is applied
                solely to noise generated here.
            generator: RNG used when the starting noise is generated here.
            **kwargs: accepted for call compatibility and ignored.

        Returns:
            One PIL image per row of the latent batch.
        """
        model = self.model
        device = model.device_torch
        dtype = model.torch_dtype
        transformer = model.transformer

        # Always sample with a FRESH scheduler -- the training scheduler is
        # stateful and mutating it mid-training would corrupt the train step.
        scheduler = model.get_train_scheduler()
        scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = scheduler.timesteps  # 1000 -> 0 scale

        do_cfg = unconditional_embeds is not None and guidance_scale != 1.0

        # 1. text features + masks (done first so the noise batch can follow
        #    the prompt batch)
        cond_feats, cond_mask = self._embeds_and_mask(conditional_embeds, device, dtype)
        if do_cfg:
            uncond_feats, uncond_mask = self._embeds_and_mask(
                unconditional_embeds, device, dtype
            )

        # 2. starting noise -- pixel space, so channels = transformer.in_channels (3),
        #    spatial size = the requested pixels (no VAE downsample). Scaled by the
        #    model's noise_scale to match the learned flow-matching trajectory.
        #    The batch size follows the prompt embeddings instead of being pinned
        #    to 1, otherwise a multi-prompt batch cannot be attended against a
        #    single image in the transformer.
        if latents is None:
            batch_size = cond_feats.shape[0]
            shape = (batch_size, transformer.in_channels, height, width)
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=torch.float32
            )
            latents = latents * model.noise_scale
        latents = latents.to(device, dtype=torch.float32)

        # 3. denoising loop
        for t in timesteps:
            # scheduler timesteps are 0-1000; the transformer wants [0, 1]
            t01 = (t / 1000.0).to(device).float().expand(latents.shape[0])

            # cast once per step; both CFG branches see the same input
            model_input = latents.to(dtype)

            x0_cond = transformer(
                hidden_states=model_input,
                timestep=t01,
                encoder_hidden_states=cond_feats,
                attention_mask=cond_mask,
                return_dict=False,
            )[0]
            if do_cfg:
                x0_uncond = transformer(
                    hidden_states=model_input,
                    timestep=t01,
                    encoder_hidden_states=uncond_feats,
                    attention_mask=uncond_mask,
                    return_dict=False,
                )[0]
                # classifier-free guidance is applied on the x0 prediction
                x0 = x0_uncond + guidance_scale * (x0_cond - x0_uncond)
            else:
                x0 = x0_cond

            # convert the x0 (clean-image) prediction to the flow velocity the
            # scheduler consumes: v = (x_t - x0) / t, t clamped for stability.
            t_x = torch.clamp(t01.to(torch.float32), min=X_PRED_T_MIN).view(-1, 1, 1, 1)
            v = (latents - x0.to(torch.float32)) / t_x

            latents = scheduler.step(v, t, latents, return_dict=False)[0]

        # 4. pixel space: the denoised latents ARE the image in [-1, 1].
        #    decode_latents is an identity (FakeVAE) but we route through it so
        #    any future latent normalization stays in one place.
        images = model.decode_latents(latents, device=device, dtype=torch.float32)
        images = images.float().clamp(-1.0, 1.0)
        # map [-1, 1] -> [0, 255] and go channels-last for PIL
        images = ((images + 1.0) * 127.5).round().to(torch.uint8)
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        return [Image.fromarray(arr) for arr in images]
def maybe_adjust_dtype_for_device(
    dtype: torch.dtype, device: torch.device
) -> torch.dtype:
    r"""
    Return a dtype that is usable on `device`.

    MPS does not implement float64, so a float64 request on an `"mps"` device is downgraded to float32. Every other
    combination (including `device=None`) returns `dtype` unchanged. Inlined from newer diffusers so this file works
    on older installs.
    """
    device_kind = None if device is None else getattr(device, "type", None)
    needs_fallback = device_kind == "mps" and dtype == torch.float64
    return torch.float32 if needs_fallback else dtype


def get_image_ids(
    batch_size: int, height: int, width: int, patch_size: int, device: torch.device
) -> torch.Tensor:
    r"""
    Build the (row, col) grid coordinate of every patch, repeated for each batch element.

    Args:
        batch_size (`int`):
            Number of images in the batch.
        height (`int`):
            Height of the input images (in pixels).
        width (`int`):
            Width of the input images (in pixels).
        patch_size (`int`):
            Side length of the square patches; the grid is `height // patch_size` by `width // patch_size`.
        device (`torch.device`):
            The device on which to create the tensor.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(batch_size, num_patches, 2)` in the default floating dtype, row-major over the grid,
            with `[..., 0]` the patch row and `[..., 1]` the patch column.
    """
    rows = height // patch_size
    cols = width // patch_size

    row_ids = torch.arange(rows, device=device).view(rows, 1).expand(rows, cols)
    col_ids = torch.arange(cols, device=device).view(1, cols).expand(rows, cols)

    # integer coordinates are stored in the default float dtype
    grid = torch.stack((row_ids, col_ids), dim=-1).to(torch.get_default_dtype())
    return grid.reshape(1, rows * cols, 2).repeat(batch_size, 1, 1)


def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    r"""
    Rotate consecutive feature pairs of `xq` with precomputed 2x2 rotation matrices (RoPE).

    Args:
        xq (`torch.Tensor`):
            Input tensor of shape `(..., dim)`; the last axis is read as `dim / 2` pairs.
        freqs_cis (`torch.Tensor`):
            Rotation matrices of shape `(..., dim/2, 2, 2)`, broadcastable against the pairs of `xq`.

    Returns:
        `torch.Tensor`:
            Tensor with the shape and dtype of `xq`; the arithmetic itself runs in float32.
    """
    pairs = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    # Keep the rotation table on the query's device to avoid mismatches with offloading
    rot = freqs_cis.to(device=xq.device, dtype=pairs.dtype)

    first, second = pairs[..., 0], pairs[..., 1]
    rotated = rot[..., 0] * first + rot[..., 1] * second
    return rotated.reshape(*xq.shape).type_as(xq)
+ ) + + def __call__( + self, + attn: "PRXAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + """ + Apply PRX attention using PRXAttention module. + + Args: + attn: PRXAttention module containing projection layers + hidden_states: Image tokens [B, L_img, D] + encoder_hidden_states: Text tokens [B, L_txt, D] + attention_mask: Boolean mask for text tokens [B, L_txt] + image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2] + """ + + if encoder_hidden_states is None: + raise ValueError( + "PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens." + ) + + # Project image tokens to Q, K, V + img_qkv = attn.img_qkv_proj(hidden_states) + B, L_img, _ = img_qkv.shape + img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim) + img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D] + img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2] + + # Apply QK normalization to image tokens + img_q = attn.norm_q(img_q) + img_k = attn.norm_k(img_k) + + # Project text tokens to K, V + txt_kv = attn.txt_kv_proj(encoder_hidden_states) + B, L_txt, _ = txt_kv.shape + txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim) + txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D] + txt_k, txt_v = txt_kv[0], txt_kv[1] + + # Apply K normalization to text tokens + txt_k = attn.norm_added_k(txt_k) + + # Apply RoPE to image queries and keys + if image_rotary_emb is not None: + img_q = apply_rope(img_q, image_rotary_emb) + img_k = apply_rope(img_k, image_rotary_emb) + + # Concatenate text and image keys/values + k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D] + v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D] + + # Build attention mask if provided + attn_mask_tensor = None + if attention_mask is not None: + bs, _, l_img, 
class PRXAttention(nn.Module, AttentionModuleMixin):
    r"""
    PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
    PRX's architecture.

    The module only owns the layers (fused image QKV projection, fused text KV projection, per-head RMS norms and the
    output projection); the attention computation itself is delegated to the configured processor.

    Args:
        query_dim (`int`): Width of the image and text token features.
        heads (`int`, defaults to 8): Number of attention heads.
        dim_head (`int`, defaults to 64): Feature size of each head.
        bias (`bool`, defaults to `False`): Whether the QKV / KV projections carry a bias.
        out_bias (`bool`, defaults to `False`): Whether the output projection carries a bias.
        eps (`float`, defaults to 1e-6): Epsilon of the RMS norms.
        processor (*optional*): Attention processor; a fresh `_default_processor_cls()` is used when `None`.
    """

    _default_processor_cls = PRXAttnProcessor2_0
    _available_processors = [PRXAttnProcessor2_0]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        bias: bool = False,
        out_bias: bool = False,
        eps: float = 1e-6,
        processor=None,
    ):
        super().__init__()

        inner_dim = dim_head * heads
        self.query_dim = query_dim
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = inner_dim

        def make_head_norm() -> RMSNorm:
            # every Q/K norm acts on a single head's features
            return RMSNorm(dim_head, eps=eps, elementwise_affine=True)

        # NOTE(review): the fused projections are sized from query_dim while the
        # output projection reads inner_dim, so this looks like it assumes
        # query_dim == heads * dim_head -- confirm before using other sizes.
        self.img_qkv_proj = nn.Linear(query_dim, 3 * query_dim, bias=bias)
        self.norm_q = make_head_norm()
        self.norm_k = make_head_norm()

        self.txt_kv_proj = nn.Linear(query_dim, 2 * query_dim, bias=bias)
        self.norm_added_k = make_head_norm()

        # index 0: output projection, index 1: dropout (p=0.0)
        self.to_out = nn.ModuleList(
            [
                nn.Linear(inner_dim, query_dim, bias=out_bias),
                nn.Dropout(0.0),
            ]
        )

        self.set_processor(
            self._default_processor_cls() if processor is None else processor
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Forward every argument, plus this module itself, to the active processor."""
        run_attention = self.processor
        return run_attention(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
            **kwargs,
        )
class MLPEmbedder(nn.Module):
    r"""
    Two linear layers with a SiLU in between, used to embed conditioning vectors.

    Args:
        in_dim (`int`):
            Dimensionality of the input features.
        hidden_dim (`int`):
            Dimensionality of the hidden and output embedding space.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(..., hidden_dim)` containing the embedded representations.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.in_layer(x)
        activated = self.silu(hidden)
        return self.out_layer(activated)


class PRXResolutionEmbedder(nn.Module):
    r"""
    Embeds the spatial resolution `(height, width)` of the latent into a vector that is added to the timestep
    embedding, so the model can condition its modulation on the generation resolution.

    Height and width each get a 128-dim sinusoidal embedding; the two are concatenated (height first) into a 256-dim
    vector, cast to the requested dtype and projected to `hidden_size` by a 2-layer MLP. This matches the `"vec"`
    mode of the resolution-aware conditioning used during PRX-7B training.

    Args:
        hidden_size (`int`):
            Dimension of the output embedding (must match the timestep embedding dimension).
        max_period (`int`, *optional*, defaults to 10000):
            Maximum frequency period for the sinusoidal resolution embedding.
    """

    def __init__(self, hidden_size: int, max_period: int = 10000):
        super().__init__()
        self.max_period = max_period
        self.mlp = MLPEmbedder(in_dim=256, hidden_dim=hidden_size)

    def forward(
        self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        def sinusoid(values: torch.Tensor) -> torch.Tensor:
            # identical settings for both axes
            return get_timestep_embedding(
                timesteps=values,
                embedding_dim=128,
                max_period=self.max_period,
                scale=1.0,
                flip_sin_to_cos=True,
                downscale_freq_shift=0.0,
            )

        per_axis = [sinusoid(height), sinusoid(width)]
        hw_emb = torch.cat(per_axis, dim=-1).to(dtype)
        return self.mlp(hw_emb)


class Modulation(nn.Module):
    r"""
    Modulation network that generates scale, shift, and gating parameters.

    The input vector goes through SiLU and one linear layer producing `6 * dim` features. A singleton axis is
    inserted at position 1 and the result is split into six chunks, grouped into two `(shift, scale, gate)` tuples.
    The linear layer starts with all-zero weight and bias.

    Args:
        dim (`int`):
            Dimensionality of the input vector. The output will have `6 * dim` features internally.

    Returns:
        ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
            Two tuples `(shift, scale, gate)`.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, 6 * dim, bias=True)
        # zero-initialise both parameters of the projection
        for param in (self.lin.weight, self.lin.bias):
            nn.init.constant_(param, 0)

    def forward(
        self, vec: torch.Tensor
    ) -> tuple[
        tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        projected = self.lin(nn.functional.silu(vec))
        chunks = projected.unsqueeze(1).chunk(6, dim=-1)
        first_group = tuple(chunks[:3])
        second_group = tuple(chunks[3:])
        return first_group, second_group
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.hidden_size = hidden_size + + # Pre-attention normalization for image tokens + self.img_pre_norm = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6 + ) + + # PRXAttention module with built-in projections and norms + self.attention = PRXAttention( + query_dim=hidden_size, + heads=num_heads, + dim_head=self.head_dim, + bias=False, + out_bias=False, + eps=1e-6, + processor=PRXAttnProcessor2_0(), + ) + + # mlp + self.post_attention_layernorm = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6 + ) + self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) + self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) + self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False) + self.mlp_act = nn.GELU(approximate="tanh") + + self.modulation = Modulation(hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: dict[str, Any], + ) -> torch.Tensor: + r""" + Runs modulation-gated cross-attention and MLP, with residual connections. + + Args: + hidden_states (`torch.Tensor`): + Image tokens of shape `(B, L_img, hidden_size)`. + encoder_hidden_states (`torch.Tensor`): + Text tokens of shape `(B, L_txt, hidden_size)`. + temb (`torch.Tensor`): + Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or + broadcastable). + image_rotary_emb (`torch.Tensor`): + Rotary positional embeddings applied inside attention. 
class FinalLayer(nn.Module):
    r"""
    Final projection layer with adaptive LayerNorm modulation.

    The conditioning vector is turned into a `(shift, scale)` pair by SiLU + linear; tokens are layer-normalised
    (no learned affine), modulated as `(1 + scale) * norm(x) + shift` and projected to patch-level outputs.

    Args:
        hidden_size (`int`):
            Dimensionality of the input tokens.
        patch_size (`int`):
            Size of the square image patches.
        out_channels (`int`):
            Number of output channels per pixel (e.g. RGB = 3).
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        features_per_token = patch_size * patch_size * out_channels

        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, features_per_token, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        # first half of the projection is the shift, second half the scale
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        modulated = (1 + scale.unsqueeze(1)) * self.norm_final(x) + shift.unsqueeze(1)
        return self.linear(modulated)


def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    r"""
    Flattens an image tensor into a sequence of non-overlapping patches.

    Patches are ordered row-major over the patch grid; inside a patch the features are ordered channel first, then
    patch row, then patch column.

    Args:
        img (`torch.Tensor`):
            Input image tensor of shape `(B, C, H, W)`.
        patch_size (`int`):
            Size of each square patch. Must evenly divide both `H` and `W`.

    Returns:
        `torch.Tensor`:
            Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) *
            (W // patch_size)` is the number of patches.
    """
    batch, channels, height, width = img.shape
    grid_h = height // patch_size
    grid_w = width // patch_size

    # (B, C, H, W) -> (B, C, grid_h, p, grid_w, p): split each spatial axis into grid and in-patch parts
    tiles = img.reshape(batch, channels, grid_h, patch_size, grid_w, patch_size)

    # -> (B, grid_h, grid_w, C, p, p): grid axes first, then the contents of one patch
    tiles = tiles.permute(0, 2, 4, 1, 3, 5)

    # -> (B, L, C * p * p)
    return tiles.reshape(batch, -1, channels * patch_size * patch_size)
class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
    r"""
    Transformer-based 2D model for text to image generation.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            Number of input channels in the latent image. The pixel-space PRX uses `3` (raw RGB, no VAE).
        patch_size (`int`, *optional*, defaults to 2):
            Size of the square patches used to flatten the input image.
        context_in_dim (`int`, *optional*, defaults to 2304):
            Dimensionality of the text conditioning input.
        hidden_size (`int`, *optional*, defaults to 1792):
            Dimension of the hidden representation.
        mlp_ratio (`float`, *optional*, defaults to 3.5):
            Expansion ratio for the hidden dimension inside MLP blocks.
        num_heads (`int`, *optional*, defaults to 28):
            Number of attention heads.
        depth (`int`, *optional*, defaults to 16):
            Number of transformer blocks.
        axes_dim (`list[int]`, *optional*):
            list of dimensions for each positional embedding axis. Defaults to `[32, 32]`. Must sum to
            `hidden_size // num_heads`.
        theta (`int`, *optional*, defaults to 10000):
            Frequency scaling factor for rotary embeddings.
        time_factor (`float`, *optional*, defaults to 1000.0):
            Scaling factor applied in timestep embeddings.
        time_max_period (`int`, *optional*, defaults to 10000):
            Maximum frequency period for timestep embeddings.
        bottleneck_size (`int`, *optional*):
            If set, the image patch projection (`img_in`) uses a two-layer bottleneck (`patch_dim -> bottleneck_size ->
            hidden_size`) instead of a single linear layer. Used by the pixel-space PRX-7B variant where the patch
            dimension is large.
        resolution_embeds (`bool`, *optional*, defaults to `False`):
            Whether to condition the timestep modulation on the latent resolution `(H, W)` via a
            `PRXResolutionEmbedder`. Used by the PRX-7B variant.

    Raises:
        ValueError: If `hidden_size` is not divisible by `num_heads`, or `axes_dim` does not sum to the head size.
    """

    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        patch_size: int = 2,
        context_in_dim: int = 2304,
        hidden_size: int = 1792,
        mlp_ratio: float = 3.5,
        num_heads: int = 28,
        depth: int = 16,
        axes_dim: list = None,
        theta: int = 10000,
        time_factor: float = 1000.0,
        time_max_period: int = 10000,
        bottleneck_size: Optional[int] = None,
        resolution_embeds: bool = False,
    ):
        super().__init__()

        if axes_dim is None:
            axes_dim = [32, 32]

        # --- validate the head geometry before building any layer ---
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
            )
        pe_dim = hidden_size // num_heads
        if sum(axes_dim) != pe_dim:
            raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")

        # --- plain attributes ---
        # one patch token carries in_channels * patch_size**2 values, on input and on output
        patch_dim = in_channels * patch_size**2
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.out_channels = patch_dim
        self.time_factor = time_factor
        self.time_max_period = time_max_period
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # --- sub-modules (creation order is kept as-is on purpose) ---
        self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)

        if bottleneck_size is None:
            self.img_in = nn.Linear(patch_dim, hidden_size, bias=True)
        else:
            # Two-layer bottleneck projection (used by pixel-space PRX where the patch dimension is large).
            self.img_in = nn.Sequential(
                nn.Linear(patch_dim, bottleneck_size, bias=True),
                nn.Linear(bottleneck_size, hidden_size, bias=True),
            )

        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size)
        self.txt_in = nn.Linear(context_in_dim, hidden_size)

        if resolution_embeds:
            self.resolution_embedder = PRXResolutionEmbedder(
                hidden_size, max_period=time_max_period
            )
        else:
            self.resolution_embedder = None

        self.blocks = nn.ModuleList(
            [PRXBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
        )

        # FinalLayer is given patch_size=1 because out_channels already includes patch_size**2
        self.final_layer = FinalLayer(hidden_size, 1, self.out_channels)

        self.gradient_checkpointing = False

    def _compute_timestep_embedding(
        self, timestep: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        """Sinusoidal 256-dim embedding of `timestep`, cast to `dtype` and passed through `time_in`."""
        sinusoid = get_timestep_embedding(
            timesteps=timestep,
            embedding_dim=256,
            max_period=self.time_max_period,
            scale=self.time_factor,
            flip_sin_to_cos=True,  # Match original cos, sin order
            downscale_freq_shift=0.0,
        )
        return self.time_in(sinusoid.to(dtype))

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        attention_kwargs: dict[str, Any] | None = None,
        return_dict: bool = True,
    ) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
        r"""
        Forward pass of the PRXTransformer2DModel.

        The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
        transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.

        Args:
            hidden_states (`torch.Tensor`):
                Input latent image tensor of shape `(B, C, H, W)`.
            timestep (`torch.Tensor`):
                Timestep tensor of shape `(B,)` or `(1,)` on a `[0, 1]` scale, used for temporal conditioning.
            encoder_hidden_states (`torch.Tensor`):
                Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
            attention_mask (`torch.Tensor`, *optional*):
                Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
            attention_kwargs (`dict`, *optional*):
                Accepted for signature compatibility; not read by this method.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to wrap the result in a `Transformer2DModelOutput`.

        Returns:
            `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple `(sample,)` with the predicted clean
            image of shape `(B, C, H, W)`.
        """
        device = hidden_states.device
        batch, _, height_px, width_px = hidden_states.shape

        # Text conditioning
        text_tokens = self.txt_in(encoder_hidden_states)

        # Image -> patch tokens -> hidden size
        image_tokens = self.img_in(img2seq(hidden_states, self.patch_size))

        # Rotary positional embeddings from the patch grid coordinates
        patch_ids = get_image_ids(
            batch, height_px, width_px, patch_size=self.patch_size, device=device
        )
        rotary = self.pe_embedder(patch_ids)

        # Timestep conditioning vector
        cond = self._compute_timestep_embedding(timestep, dtype=image_tokens.dtype)

        # Add resolution conditioning (PRX-7B "vec" mode): embed the latent (H, W) and add it to the timestep vector
        # so every block's modulation is resolution-aware.
        if self.resolution_embedder is not None:
            height_vec = torch.full(
                (batch,), height_px, device=device, dtype=torch.float32
            )
            width_vec = torch.full(
                (batch,), width_px, device=device, dtype=torch.float32
            )
            cond = cond + self.resolution_embedder(
                height_vec, width_vec, dtype=cond.dtype
            )

        # Transformer blocks, optionally under gradient checkpointing
        checkpointed = torch.is_grad_enabled() and self.gradient_checkpointing
        for block in self.blocks:
            if checkpointed:
                # positional order matches PRXBlock.forward:
                # hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
                image_tokens = self._gradient_checkpointing_func(
                    block.__call__,
                    image_tokens,
                    text_tokens,
                    cond,
                    rotary,
                    attention_mask,
                )
            else:
                image_tokens = block(
                    hidden_states=image_tokens,
                    encoder_hidden_states=text_tokens,
                    temb=cond,
                    image_rotary_emb=rotary,
                    attention_mask=attention_mask,
                )

        # Project tokens to patch values and fold them back into an image
        image_tokens = self.final_layer(image_tokens, cond)
        output = seq2img(image_tokens, self.patch_size, hidden_states.shape)

        if return_dict:
            return Transformer2DModelOutput(sample=output)
        return (output,)
def __init__(
    self,
    device,
    model_config: ModelConfig,
    dtype="bf16",
    custom_pipeline=None,
    noise_scheduler=None,
    **kwargs,
):
    """Initialise the Qwen Image model wrapper.

    Every argument is handed to the parent initializer unchanged and in the
    same positional order; afterwards three attributes describing this
    architecture are set on the instance.
    """
    parent_args = (device, model_config, dtype, custom_pipeline, noise_scheduler)
    super().__init__(*parent_args, **kwargs)

    # Module class name(s) that LoRA should be applied to.
    self.target_lora_modules = ["QwenImageTransformer2DModel"]
    # Architecture flags set for this model.
    self.is_transformer = True
    self.is_flow_matching = True
def load_model(self):
    """Load transformer, text encoder, tokenizer and VAE, then build the pipeline.

    Reads the model location from ``self.model_config.name_or_path`` and the
    location of the auxiliary components (tokenizer, text encoder, VAE,
    processor) from ``self.model_config.extras_name_or_path``.  Depending on
    the config flags the transformer and text encoder are quantized and/or
    given layer offloading.  On completion the following attributes are set:
    ``vae``, ``text_encoder`` (one-element list), ``tokenizer`` (one-element
    list), ``model``, ``pipeline``, ``noise_scheduler`` and ``processor``.
    """
    dtype = self.torch_dtype
    self.print_and_status_update("Loading Qwen Image model")
    model_path = self.model_config.name_or_path
    base_model_path = self.model_config.extras_name_or_path
    model_dtype = dtype

    # A single-file checkpoint cannot supply tokenizer / text encoder / VAE,
    # so those are taken from the reference repo instead.
    # NOTE(review): assumes extras_name_or_path is never None here - confirm.
    if base_model_path.endswith(".safetensors"):
        # use the repo for extras
        base_model_path = "Qwen/Qwen-Image"

    self.print_and_status_update("Loading transformer")

    if model_path.endswith(".safetensors"):
        # load the safetensors file
        transformer = QwenImageTransformer2DModel.from_single_file(
            model_path,
            config="Qwen/Qwen-Image",
            subfolder="transformer",
            torch_dtype=model_dtype,
        )
        transformer.to(model_dtype)

    else:
        # Default: treat model_path as a hub repo id and load its
        # "transformer" subfolder.
        transformer_path = model_path
        transformer_subfolder = "transformer"
        if os.path.exists(transformer_path):
            # Local directory: point directly at <model_path>/transformer and
            # drop the subfolder argument.
            transformer_subfolder = None
            transformer_path = os.path.join(transformer_path, "transformer")
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, "text_encoder")
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path

        transformer = QwenImageTransformer2DModel.from_pretrained(
            transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
        )

    # Optional transformer quantization, driven by the model config.
    if self.model_config.quantize:
        self.print_and_status_update("Quantizing Transformer")
        quantize_model(self, transformer)
        flush()

    # Optional layer offloading for the transformer; only attached when the
    # configured percentage is positive.
    if (
        self.model_config.layer_offloading
        and self.model_config.layer_offloading_transformer_percent > 0
    ):
        MemoryManager.attach(
            transformer,
            self.device_torch,
            offload_percent=self.model_config.layer_offloading_transformer_percent,
        )

    # In low-VRAM mode the transformer is parked on the CPU while the other
    # components are loaded.
    if self.model_config.low_vram:
        self.print_and_status_update("Moving transformer to CPU")
        transformer.to("cpu")

    flush()

    self.print_and_status_update("Text Encoder")
    # NOTE(review): torch_dtype is also passed to the tokenizer loader; it is
    # presumably ignored there - confirm.
    tokenizer = Qwen2Tokenizer.from_pretrained(
        base_model_path, subfolder="tokenizer", torch_dtype=dtype
    )
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        base_model_path, subfolder="text_encoder", torch_dtype=dtype
    )

    # remove the visual model as it is not needed for image generation
    # (subclasses that set _qwen_image_keep_visual keep it and load a processor below)
    self.processor = None
    if not self._qwen_image_keep_visual:
        text_encoder.model.visual = None

    # Optional layer offloading for the text encoder.
    if (
        self.model_config.layer_offloading
        and self.model_config.layer_offloading_text_encoder_percent > 0
    ):
        MemoryManager.attach(
            text_encoder,
            self.device_torch,
            offload_percent=self.model_config.layer_offloading_text_encoder_percent,
        )

    text_encoder.to(self.device_torch, dtype=dtype)
    flush()

    # Optional text-encoder quantization: quantize the weights, then freeze.
    if self.model_config.quantize_te:
        self.print_and_status_update("Quantizing Text Encoder")
        quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
        freeze(text_encoder)
        flush()

    self.print_and_status_update("Loading VAE")
    vae = AutoencoderKLQwenImage.from_pretrained(
        base_model_path, subfolder="vae", torch_dtype=dtype
    )

    self.noise_scheduler = QwenImageModel.get_train_scheduler()

    self.print_and_status_update("Making pipe")

    # Extra pipeline constructor arguments; only the processor, and only for
    # variants that keep the visual tower.
    kwargs = {}

    if self._qwen_image_keep_visual:
        # Prefer a processor shipped with the model itself; fall back to the
        # base/extras location when loading from model_path raises OSError.
        try:
            self.processor = Qwen2VLProcessor.from_pretrained(
                model_path, subfolder="processor"
            )
        except OSError:
            self.processor = Qwen2VLProcessor.from_pretrained(
                base_model_path, subfolder="processor"
            )
        kwargs["processor"] = self.processor

    # The pipeline class comes from the class attribute _qwen_pipeline so
    # subclasses can substitute their own.  Text encoder and transformer are
    # passed as None and assigned afterwards.
    pipe: QwenImagePipeline = self._qwen_pipeline(
        scheduler=self.noise_scheduler,
        text_encoder=None,
        tokenizer=tokenizer,
        vae=vae,
        transformer=None,
        **kwargs,
    )
    # for quantization, it works best to do these after making the pipe
    pipe.text_encoder = text_encoder
    pipe.transformer = transformer

    self.print_and_status_update("Preparing Model")

    # The rest of the class indexes these as lists (text_encoder[0], tokenizer[0]).
    text_encoder = [pipe.text_encoder]
    tokenizer = [pipe.tokenizer]

    # leave it on cpu for now
    # NOTE(review): this reads self.low_vram while the earlier check reads
    # self.model_config.low_vram - presumably the same value; confirm.
    if not self.low_vram:
        pipe.transformer = pipe.transformer.to(self.device_torch)

    flush()
    # just to make sure everything is on the right device and dtype
    text_encoder[0].to(self.device_torch)
    text_encoder[0].requires_grad_(False)
    text_encoder[0].eval()
    flush()

    # save it to the model class
    self.vae = vae
    self.text_encoder = text_encoder  # list of text encoders
    self.tokenizer = tokenizer  # list of tokenizers
    self.model = pipe.transformer
    self.pipeline = pipe
    self.print_and_status_update("Model Loaded")
def get_noise_prediction(
    self,
    latent_model_input: torch.Tensor,
    timestep: torch.Tensor,  # 0 to 1000 scale
    text_embeddings: PromptEmbeds,
    **kwargs,
):
    """Run the transformer on patch-packed latents and return its output in latent layout.

    The 4-D input ``(batch, channels, height, width)`` is split into
    ``patch x patch`` tiles (patch size read from the transformer config),
    flattened to a token sequence, passed through ``self.transformer`` with
    the prompt embeddings and their mask, and the first element of the
    returned tuple is folded back to the input's 4-D shape.  ``timestep`` is
    divided by 1000 before being handed to the transformer.  Extra keyword
    arguments are forwarded to the transformer call.
    """
    self.model.to(self.device_torch)

    bs, channels, h, w = latent_model_input.shape
    patch = self.transformer.config.patch_size
    grid_h = h // patch
    grid_w = w // patch

    # (B, C, H, W) -> (B, grid_h * grid_w, C * patch * patch)
    tokens = (
        latent_model_input.view(bs, channels, grid_h, patch, grid_w, patch)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(bs, grid_h * grid_w, channels * (patch * patch))
    )

    # one (frames, rows, cols) entry per batch item, describing the token grid
    img_shapes = [[(1, grid_h, grid_w)]] * bs

    text_states = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
    text_mask = text_embeddings.attention_mask.to(
        self.device_torch, dtype=torch.int64
    )

    model_out = self.transformer(
        hidden_states=tokens.to(self.device_torch, self.torch_dtype).detach(),
        timestep=(timestep / 1000).detach(),
        guidance=None,
        encoder_hidden_states=text_states.detach(),
        encoder_hidden_states_mask=text_mask.detach(),
        img_shapes=img_shapes,
        return_dict=False,
        **kwargs,
    )
    prediction = model_out[0]

    # (B, grid_h * grid_w, C * patch * patch) -> (B, C, H, W)
    return (
        prediction.view(bs, grid_h, grid_w, channels, patch, patch)
        .permute(0, 3, 1, 4, 2, 5)
        .reshape(bs, channels, h, w)
    )
def decode_latents(self, latents: torch.Tensor, device=None, dtype=None):
    """Decode normalised latents with the VAE and return the resulting images.

    ``device`` and ``dtype`` fall back to ``self.vae_device_torch`` and
    ``self.vae_torch_dtype`` when omitted.  The latents are de-normalised with
    the per-channel ``latents_std`` / ``latents_mean`` from the VAE config
    (``latents * std + mean``), decoded, and returned on ``device`` in
    ``dtype``.  A singleton dimension is inserted at index 2 before decoding
    and removed again afterwards.
    """
    device = self.vae_device_torch if device is None else device
    dtype = self.vae_torch_dtype if dtype is None else dtype

    vae = self.vae
    if vae.device == torch.device("cpu"):
        vae.to(device)

    # add frame count dim for wan vae
    z = latents.to(device, dtype=dtype).unsqueeze(2)

    cfg = vae.config
    stat_device, stat_dtype = z.device, z.dtype

    def _per_channel(values):
        # config list -> tensor shaped (1, z_dim, 1, 1, 1) for broadcasting
        return torch.tensor(values).view(1, cfg.z_dim, 1, 1, 1).to(
            stat_device, stat_dtype
        )

    mean = _per_channel(cfg.latents_mean)
    std = _per_channel(cfg.latents_std)
    z = z * std + mean

    decoded = vae.decode(z).sample

    # remove the frame count dimension
    return decoded.squeeze(2).to(device, dtype=dtype)
def generate_single_image(
    self,
    pipeline: QwenImageEditPipeline,
    gen_config: GenerateImageConfig,
    conditional_embeds: PromptEmbeds,
    unconditional_embeds: PromptEmbeds,
    generator: torch.Generator,
    extra: dict,
):
    """Render one sample image with the edit pipeline and return it.

    ``gen_config.width`` / ``gen_config.height`` are rounded down in place to
    a multiple of ``self.get_bucket_divisibility()``.  When
    ``gen_config.ctrl_img`` is set, that file is opened, converted to RGB and
    resized to the rounded size before being passed as the ``image`` input;
    otherwise ``image`` is ``None``.  Entries of ``extra`` are forwarded to
    the pipeline call as additional keyword arguments.
    """
    self.model.to(self.device_torch, dtype=self.torch_dtype)

    multiple = self.get_bucket_divisibility()
    gen_config.width = int(gen_config.width // multiple * multiple)
    gen_config.height = int(gen_config.height // multiple * multiple)

    control_img = None
    if gen_config.ctrl_img is not None:
        target_size = (gen_config.width, gen_config.height)
        control_img = Image.open(gen_config.ctrl_img).convert("RGB")
        # resize to width and height
        if control_img.size != target_size:
            control_img = control_img.resize(target_size, Image.BILINEAR)

    # flush for low vram if we are doing that
    flush_between_steps = self.model_config.low_vram

    # Fix a bug in diffusers/torch
    def _on_step_end(pipe, i, t, callback_kwargs):
        if flush_between_steps:
            flush()
        return {"latents": callback_kwargs["latents"]}

    result = pipeline(
        image=control_img,
        prompt_embeds=conditional_embeds.text_embeds,
        prompt_embeds_mask=conditional_embeds.attention_mask.to(
            self.device_torch, dtype=torch.int64
        ),
        negative_prompt_embeds=unconditional_embeds.text_embeds,
        negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(
            self.device_torch, dtype=torch.int64
        ),
        height=gen_config.height,
        width=gen_config.width,
        num_inference_steps=gen_config.num_inference_steps,
        true_cfg_scale=gen_config.guidance_scale,
        latents=gen_config.latents,
        generator=generator,
        callback_on_step_end=_on_step_end,
        **extra,
    )
    return result.images[0]
def get_noise_prediction(
    self,
    latent_model_input: torch.Tensor,
    timestep: torch.Tensor,  # 0 to 1000 scale
    text_embeddings: PromptEmbeds,
    **kwargs,
):
    """Predict noise for the target latents, conditioned on a control latent.

    ``latent_model_input`` carries the noisy target and the control latent
    stacked along the channel axis.  The two halves are split apart, each is
    packed into 2x2 patch tokens, and the two token sequences are joined
    along the sequence axis before calling ``self.transformer``.  Only the
    leading tokens (those belonging to the target) of the first returned
    element are kept and folded back to ``(batch, channels, height, width)``.
    ``timestep`` is divided by 1000; extra keyword arguments are forwarded to
    the transformer call.
    """
    # control is stacked on channels, split it off before packing
    target, control = torch.chunk(latent_model_input, 2, 1)

    bs, channels, h, w = target.shape
    _, _, control_h, control_w = control.shape
    grid_h, grid_w = h // 2, w // 2

    def _pack(t):
        # (B, C, H, W) -> (B, grid_h * grid_w, C * 4)
        t = t.view(bs, channels, grid_h, 2, grid_w, 2)
        t = t.permute(0, 2, 4, 1, 3, 5)
        return t.reshape(bs, grid_h * grid_w, channels * 4)

    target_tokens = _pack(target)
    control_tokens = _pack(control)

    # per batch item: token grid of the target, then of the control
    img_shapes = [[(1, grid_h, grid_w), (1, control_h // 2, control_w // 2)]] * bs

    joined = torch.cat([target_tokens, control_tokens], dim=1)
    bs = joined.shape[0]

    text_mask = text_embeddings.attention_mask.to(
        self.device_torch, dtype=torch.int64
    )
    txt_seq_lens = text_mask.sum(dim=1).tolist()
    text_states = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)

    model_out = self.transformer(
        hidden_states=joined.to(self.device_torch, self.torch_dtype),
        timestep=timestep / 1000,
        guidance=None,
        encoder_hidden_states=text_states,
        encoder_hidden_states_mask=text_mask.detach(),
        img_shapes=img_shapes,
        txt_seq_lens=txt_seq_lens,
        return_dict=False,
        **kwargs,
    )

    # drop the tokens that belong to the control image
    prediction = model_out[0][:, : target_tokens.size(1)]

    # unpack: (B, grid_h * grid_w, C * 4) -> (B, C, H, W)
    prediction = prediction.view(bs, grid_h, grid_w, channels, 2, 2)
    prediction = prediction.permute(0, 3, 1, 4, 2, 5)
    return prediction.reshape(bs, channels, h, w)
def get_generation_pipeline(self):
    """Build a fresh edit-plus pipeline for sampling and move it to the model device.

    A new scheduler is obtained from ``QwenImageModel.get_train_scheduler()``.
    The text encoder, VAE and transformer are passed through ``unwrap_model``
    first; tokenizer and processor are passed as stored on the instance.
    """
    scheduler = QwenImageModel.get_train_scheduler()

    components = {
        "scheduler": scheduler,
        "text_encoder": unwrap_model(self.text_encoder[0]),
        "tokenizer": self.tokenizer[0],
        "processor": self.processor,
        "vae": unwrap_model(self.vae),
        "transformer": unwrap_model(self.transformer),
    }
    pipeline = QwenImageEditPlusCustomPipeline(**components)

    return pipeline.to(self.device_torch)
def get_prompt_embeds(self, prompt: List, control_images=None) -> PromptEmbeds:
    """Encode each prompt together with its own set of control images.

    Args:
        prompt: One prompt per batch item.  A bare ``str`` is treated as a
            single-item batch.
        control_images: Control images for the batch.  Expected as
            ``List[List[Tensor]]`` (outer list = batch items, inner list =
            control images of that item).  A flat ``List[Tensor]`` or a single
            tensor is wrapped so it becomes the image set of one batch item.
            Each image is 3-D or 4-D; a 3-D image gets a leading batch
            dimension added.

    Returns:
        A ``PromptEmbeds`` whose embeddings and ``attention_mask`` hold one
        row per prompt, concatenated along dim 0.

    Raises:
        ValueError: if ``control_images`` is ``None`` or the number of
            prompts differs from the number of control image sets.
    """
    # todo handle not caching text encoder
    if self.pipeline.text_encoder.device != self.device_torch:
        self.pipeline.text_encoder.to(self.device_torch)

    if control_images is None:
        raise ValueError("Missing control images for QwenImageEditPlusModel")

    # A bare string is one prompt, not a sequence of characters.
    if isinstance(prompt, str):
        prompt = [prompt]

    if not isinstance(control_images, list):
        control_images = [control_images]

    # expects a list of list of control images List[List[Tensor]] where each item corresponds to a batch item,
    # and each item in the inner list corresponds to a control image for that batch item.
    # for single image/caching, it may come in as just List[Tensor], so we handle that case by wrapping it in another list
    if not isinstance(control_images[0], list):
        control_images = [control_images]

    if len(prompt) != len(control_images):
        raise ValueError("Number of prompts must match number of control image sets")

    prompt_embeds_list = []
    prompt_embeds_mask_list = []

    for item_prompt, item_control_images in zip(prompt, control_images):
        # Resize into a new list: writing the resized tensors back into the
        # caller's list would replace the caller's original-resolution
        # control images with these downscaled copies.
        resized_images = []
        for image in item_control_images:
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            # control images are 0 - 1 scale, shape (bs, ch, height, width)
            # keep the aspect ratio, target CONDITION_IMAGE_SIZE pixels,
            # snap both sides to a multiple of 32
            ratio = image.shape[2] / image.shape[3]
            height = math.sqrt(CONDITION_IMAGE_SIZE * ratio)
            width = height / ratio

            width = round(width / 32) * 32
            height = round(height / 32) * 32

            resized_images.append(
                F.interpolate(image, size=(height, width), mode="bilinear")
            )

        # Encode only this item's prompt with this item's images.  Passing the
        # whole prompt list here would pair every prompt with this item's
        # images and produce len(prompt) rows per iteration.
        prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
            [item_prompt],
            image=resized_images,
            device=self.device_torch,
            num_images_per_prompt=1,
        )
        # diffusers >=0.37 returns None when all tokens are valid (no padding)
        if prompt_embeds_mask is None:
            prompt_embeds_mask = torch.ones(
                prompt_embeds.shape[:2], device=prompt_embeds.device, dtype=torch.int64
            )
        prompt_embeds_list.append(prompt_embeds)
        prompt_embeds_mask_list.append(prompt_embeds_mask)

    # Items can yield different sequence lengths (different prompts / image
    # counts).  Right-pad shorter ones with zeros, and mark the padding as
    # masked out, so they can be concatenated along the batch dimension.
    max_len = max(embeds.shape[1] for embeds in prompt_embeds_list)
    for idx, (embeds, mask) in enumerate(
        zip(prompt_embeds_list, prompt_embeds_mask_list)
    ):
        missing = max_len - embeds.shape[1]
        if missing > 0:
            prompt_embeds_list[idx] = torch.cat(
                [embeds, embeds.new_zeros((embeds.shape[0], missing, embeds.shape[2]))],
                dim=1,
            )
            prompt_embeds_mask_list[idx] = torch.cat(
                [mask, mask.new_zeros((mask.shape[0], missing))], dim=1
            )

    pe = PromptEmbeds(torch.cat(prompt_embeds_list, dim=0))
    pe.attention_mask = torch.cat(prompt_embeds_mask_list, dim=0)
    return pe
# control tensor list is a list of tensors for this batch item + controls = [] + # pack control + for control_img in control_tensor_list: + # control images are 0 - 1 scale, shape (1, ch, height, width) + control_img = control_img.to( + self.device_torch, dtype=self.torch_dtype + ) + # if it is only 3 dim, add batch dim + if len(control_img.shape) == 3: + control_img = control_img.unsqueeze(0) + ratio = control_img.shape[2] / control_img.shape[3] + c_height = math.sqrt(control_image_res * ratio) + c_width = c_height / ratio + + c_width = round(c_width / 32) * 32 + c_height = round(c_height / 32) * 32 + + control_img = F.interpolate( + control_img, size=(c_height, c_width), mode="bilinear" + ) + + # scale to -1 to 1 + control_img = control_img * 2 - 1 + + control_latent = self.encode_images( + control_img, + device=self.device_torch, + dtype=self.torch_dtype, + ) + + clb, cl_num_channels_latents, cl_height, cl_width = ( + control_latent.shape + ) + + control = control_latent.view( + 1, + cl_num_channels_latents, + cl_height // 2, + 2, + cl_width // 2, + 2, + ) + control = control.permute(0, 2, 4, 1, 3, 5) + control = control.reshape( + 1, + (cl_height // 2) * (cl_width // 2), + num_channels_latents * 4, + ) + + img_shapes[b].append((1, cl_height // 2, cl_width // 2)) + controls.append(control) + + # stack controls on dim 1 + control = torch.cat(controls, dim=1).to( + packed_latents_list[b].device, + dtype=packed_latents_list[b].dtype, + ) + # concat with latents + packed_latents_with_control = torch.cat( + [packed_latents_list[b], control], dim=1 + ) + + packed_latents_with_controls_list.append( + packed_latents_with_control + ) + + b += 1 + + latent_model_input = torch.cat(packed_latents_with_controls_list, dim=0) + + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() + enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype) + + noise_pred = 
self.transformer( + hidden_states=latent_model_input.to( + self.device_torch, self.torch_dtype + ).detach(), + timestep=(timestep / 1000).detach(), + guidance=None, + encoder_hidden_states=enc_hs.detach(), + encoder_hidden_states_mask=prompt_embeds_mask.detach(), + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + return_dict=False, + **kwargs, + )[0] + + noise_pred = noise_pred[:, : raw_packed_latents.size(1)] + + # unpack + noise_pred = noise_pred.view( + batch_size, height // 2, width // 2, num_channels_latents, 2, 2 + ) + noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5) + noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width) + return noise_pred diff --git a/ai-toolkit/extensions_built_in/diffusion_models/qwen_image/qwen_image_pipelines.py b/ai-toolkit/extensions_built_in/diffusion_models/qwen_image/qwen_image_pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..dd6a9b1d0d9a12bb6c499a6e0de704b022d4100c --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/qwen_image/qwen_image_pipelines.py @@ -0,0 +1,354 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch + +try: + from diffusers import QwenImageEditPlusPipeline + from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import ( + CONDITION_IMAGE_SIZE, + VAE_IMAGE_SIZE, + XLA_AVAILABLE, + logger, + calculate_dimensions, + calculate_shift, + retrieve_timesteps, + ) +except ImportError: + raise ImportError( + "Diffusers is out of date. 
Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'" + ) + +from diffusers.image_processor import PipelineImageInput +from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput + + +class QwenImageEditPlusCustomPipeline(QwenImageEditPlusPipeline): + @torch.no_grad() + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + do_cfg_norm: bool = False, + ): + image_size = image[-1].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions( + 1024 * 1024, image_size[0] / image_size[1] + ) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. Preprocess image + if image is not None and not ( + isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels + ): + if not isinstance(image, list): + image = [image] + condition_image_sizes = [] + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions( + VAE_IMAGE_SIZE, image_width / image_height + ) + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append( + self.image_processor.resize(img, condition_height, condition_width) + ) + vae_images.append( + self.image_processor.preprocess( + img, vae_height, vae_width + ).unsqueeze(2) + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None + and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not 
enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + ( + 1, + height // self.vae_scale_factor // 2, + width // self.vae_scale_factor // 2, + ), + *[ + ( + 1, + vae_height // self.vae_scale_factor // 2, + vae_width // self.vae_scale_factor // 2, + ) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + # 5. 
Prepare timesteps + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if sigmas is None + else sigmas + ) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full( + [1], guidance_scale, device=device, dtype=torch.float32 + ) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = ( + prompt_embeds_mask.sum(dim=1).tolist() + if prompt_embeds_mask is not None + else None + ) + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() + if negative_prompt_embeds_mask is not None + else None + ) + + # 6. 
Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * ( + noise_pred - neg_noise_pred + ) + + if do_cfg_norm: + # the official code does this, but I find it hurts more often than it helps, leaving it optional but off by default + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + noise_pred = comb_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False + )[0] + + if 
latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor + ) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/wan22/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/wan22/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61b5e803cab8bf6f4fa8e531108d6a8360db0b34 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/wan22/__init__.py @@ -0,0 +1,3 @@ +from .wan22_5b_model 
import Wan225bModel +from .wan22_14b_model import Wan2214bModel +from .wan22_14b_i2v_model import Wan2214bI2VModel \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py b/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py new file mode 100644 index 0000000000000000000000000000000000000000..32eb11e8a60195ccb6d006a7bbccbc7a7297e778 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py @@ -0,0 +1,144 @@ +import torch +from toolkit.models.wan21.wan_utils import add_first_frame_conditioning +from toolkit.prompt_utils import PromptEmbeds +from PIL import Image +import torch +from toolkit.config_modules import GenerateImageConfig +from .wan22_pipeline import Wan22Pipeline + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from diffusers import WanImageToVideoPipeline +from torchvision.transforms import functional as TF + +from .wan22_14b_model import Wan2214bModel + +class Wan2214bI2VModel(Wan2214bModel): + arch = "wan22_14b_i2v" + + + def generate_single_image( + self, + pipeline: Wan22Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + + # todo + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) + + num_frames = ( + (gen_config.num_frames - 1) // 4 + ) * 4 + 1 # make sure it is divisible by 4 + 1 + gen_config.num_frames = num_frames + + height = gen_config.height + width = gen_config.width + first_frame_n1p1 = None + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img).convert("RGB") + + d = self.get_bucket_divisibility() + + # make sure they are divisible by d + height = height // d * d + width = width // d * d + + # resize the control image + control_img = control_img.resize((width, height), Image.LANCZOS) + + # 5. 
Prepare latent variables + # num_channels_latents = self.transformer.config.in_channels + num_channels_latents = 16 + latents = pipeline.prepare_latents( + 1, + num_channels_latents, + height, + width, + gen_config.num_frames, + torch.float32, + self.device_torch, + generator, + None, + ).to(self.torch_dtype) + + first_frame_n1p1 = ( + TF.to_tensor(control_img) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + * 2.0 + - 1.0 + ) # normalize to [-1, 1] + + # Add conditioning using the standalone function + gen_config.latents = add_first_frame_conditioning( + latent_model_input=latents, + first_frame=first_frame_n1p1, + vae=self.vae + ) + + output = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + height=height, + width=width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + num_frames=gen_config.num_frames, + generator=generator, + return_dict=False, + output_type="pil", + **extra, + )[0] + + # shape = [1, frames, channels, height, width] + batch_item = output[0] # list of pil images + if gen_config.num_frames > 1: + return batch_item # return the frames. 
+ else: + # get just the first image + img = batch_item[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: DataLoaderBatchDTO, + **kwargs + ): + # videos come in (bs, num_frames, channels, height, width) + # images come in (bs, channels, height, width) + with torch.no_grad(): + frames = batch.tensor + if len(frames.shape) == 4: + first_frames = frames + elif len(frames.shape) == 5: + first_frames = frames[:, 0] + else: + raise ValueError(f"Unknown frame shape {frames.shape}") + + # Add conditioning using the standalone function + conditioned_latent = add_first_frame_conditioning( + latent_model_input=latent_model_input, + first_frame=first_frames, + vae=self.vae + ) + + noise_pred = self.model( + hidden_states=conditioned_latent, + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds, + return_dict=False, + **kwargs + )[0] + return noise_pred \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py new file mode 100644 index 0000000000000000000000000000000000000000..aef30b7da27d6dfa45086bf01ad5c6f45b4c37ab --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -0,0 +1,594 @@ +from functools import partial +import os +from typing import Any, Dict, Optional, Union, List +from typing_extensions import Self +import torch +import yaml +from toolkit.accelerator import unwrap_model +from toolkit.basic import flush +from toolkit.models.wan21.wan_utils import add_first_frame_conditioning +from toolkit.prompt_utils import PromptEmbeds +from PIL import Image +from diffusers import UniPCMultistepScheduler +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.samplers.custom_flowmatch_sampler import ( + 
CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.util.quantize import quantize_model +from .wan22_pipeline import Wan22Pipeline +from diffusers import WanTransformer3DModel + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from torchvision.transforms import functional as TF + +from toolkit.models.wan21.wan21 import Wan21 +from .wan22_5b_model import ( + scheduler_config, + time_text_monkeypatch, +) +from toolkit.memory_management import MemoryManager +from safetensors.torch import load_file, save_file + + +boundary_ratio_t2v = 0.875 +boundary_ratio_i2v = 0.9 + +scheduler_configUniPC = { + "_class_name": "UniPCMultistepScheduler", + "_diffusers_version": "0.35.0.dev0", + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "disable_corrector": [], + "dynamic_thresholding_ratio": 0.995, + "final_sigmas_type": "zero", + "flow_shift": 3.0, + "lower_order_final": True, + "num_train_timesteps": 1000, + "predict_x0": True, + "prediction_type": "flow_prediction", + "rescale_betas_zero_snr": False, + "sample_max_value": 1.0, + "solver_order": 2, + "solver_p": None, + "solver_type": "bh2", + "steps_offset": 0, + "thresholding": False, + "time_shift_type": "exponential", + "timestep_spacing": "linspace", + "trained_betas": None, + "use_beta_sigmas": False, + "use_dynamic_shifting": False, + "use_exponential_sigmas": False, + "use_flow_sigmas": True, + "use_karras_sigmas": False, +} + + +class DualWanTransformer3DModel(torch.nn.Module): + def __init__( + self, + transformer_1: WanTransformer3DModel, + transformer_2: WanTransformer3DModel, + torch_dtype: Optional[Union[str, torch.dtype]] = None, + device: Optional[Union[str, torch.device]] = None, + boundary_ratio: float = boundary_ratio_t2v, + low_vram: bool = False, + ) -> None: + super().__init__() + self.transformer_1: WanTransformer3DModel = transformer_1 + self.transformer_2: WanTransformer3DModel = transformer_2 + self.torch_dtype: torch.dtype = torch_dtype + 
self.device_torch: torch.device = device + self.boundary_ratio: float = boundary_ratio + self.boundary: float = self.boundary_ratio * 1000 + self.low_vram: bool = low_vram + self._active_transformer_name = "transformer_1" # default to transformer_1 + + @property + def device(self) -> torch.device: + return self.device_torch + + @property + def dtype(self) -> torch.dtype: + return self.torch_dtype + + @property + def config(self): + return self.transformer_1.config + + @property + def transformer(self) -> WanTransformer3DModel: + return getattr(self, self._active_transformer_name) + + def enable_gradient_checkpointing(self): + """ + Enable gradient checkpointing for both transformers. + """ + self.transformer_1.enable_gradient_checkpointing() + self.transformer_2.enable_gradient_checkpointing() + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + **kwargs + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + # determine if doing high noise or low noise by meaning the timestep. 
+ # timesteps are in the range of 0 to 1000, so we can use a threshold + with torch.no_grad(): + if timestep.float().mean().item() > self.boundary: + t_name = "transformer_1" + else: + t_name = "transformer_2" + + # check if we are changing the active transformer, if so, we need to swap the one in + # vram if low_vram is enabled + # todo swap the loras as well + if t_name != self._active_transformer_name: + if self.low_vram: + getattr(self, self._active_transformer_name).to("cpu") + getattr(self, t_name).to(self.device_torch) + torch.cuda.empty_cache() + self._active_transformer_name = t_name + + if self.transformer.device != hidden_states.device: + if self.low_vram: + # move other transformer to cpu + other_tname = ( + "transformer_1" if t_name == "transformer_2" else "transformer_2" + ) + getattr(self, other_tname).to("cpu") + + self.transformer.to(hidden_states.device) + + return self.transformer( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=encoder_hidden_states_image, + return_dict=return_dict, + attention_kwargs=attention_kwargs, + ) + + def to(self, *args, **kwargs) -> Self: + # do not do to, this will be handled separately + return self + + +class Wan2214bModel(Wan21): + arch = "wan22_14b" + _wan_generation_scheduler_config = scheduler_configUniPC + _wan_expand_timesteps = False + _wan_vae_path = "ai-toolkit/wan2.1-vae" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device=device, + model_config=model_config, + dtype=dtype, + custom_pipeline=custom_pipeline, + noise_scheduler=noise_scheduler, + **kwargs, + ) + # target it so we can target both transformers + self.target_lora_modules = ["DualWanTransformer3DModel"] + self._wan_cache = None + + self.is_multistage = True + # multistage boundaries split the models up when sampling timesteps + # for wan 2.2 
14b. the timesteps are 1000-875 for transformer 1 and 875-0 for transformer 2 + self.multistage_boundaries: List[float] = [0.875, 0.0] + + self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True) + self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True) + + self.trainable_multistage_boundaries: List[int] = [] + if self.train_high_noise: + self.trainable_multistage_boundaries.append(0) + if self.train_low_noise: + self.trainable_multistage_boundaries.append(1) + + if len(self.trainable_multistage_boundaries) == 0: + raise ValueError( + "At least one of train_high_noise or train_low_noise must be True in model.model_kwargs" + ) + + # if we are only training one or the other, the target LoRA modules will be the wan transformer class + if not self.train_high_noise or not self.train_low_noise: + self.target_lora_modules = ["WanTransformer3DModel"] + + @property + def max_step_saves_to_keep_multiplier(self): + # the cleanup mechanism checks this to see how many saves to keep + # if we are training a LoRA, we need to set this to 2 so we keep both the high noise and low noise LoRAs at saves to keep + if ( + self.network is not None + and self.network.network_config.split_multistage_loras + ): + return 2 + return 1 + + def load_model(self): + # load model from patent parent. 
Wan21 not immediate parent + # super().load_model() + super().load_model() + + # we have to split up the model on the pipeline + self.pipeline.transformer = self.model.transformer_1 + self.pipeline.transformer_2 = self.model.transformer_2 + + # patch the condition embedder + self.model.transformer_1.condition_embedder.forward = partial( + time_text_monkeypatch, self.model.transformer_1.condition_embedder + ) + self.model.transformer_2.condition_embedder.forward = partial( + time_text_monkeypatch, self.model.transformer_2.condition_embedder + ) + + def get_bucket_divisibility(self): + # 8x compression and 2x2 patch size + return 16 + + def load_wan_transformer(self, transformer_path, subfolder=None): + if self.model_config.split_model_over_gpus: + raise ValueError( + "Splitting model over gpus is not supported for Wan2.2 models" + ) + + if ( + self.model_config.assistant_lora_path is not None + or self.model_config.inference_lora_path is not None + ): + raise ValueError( + "Assistant LoRA is not supported for Wan2.2 models currently" + ) + + if self.model_config.lora_path is not None: + raise ValueError( + "Loading LoRA is not supported for Wan2.2 models currently" + ) + + # transformer path can be a directory that ends with /transformer or a hf path. 
+ + transformer_path_1 = transformer_path + subfolder_1 = subfolder + + transformer_path_2 = transformer_path + subfolder_2 = subfolder + + if subfolder_2 is None: + # we have a local path, replace it with transformer_2 folder + transformer_path_2 = os.path.join( + os.path.dirname(transformer_path_1), "transformer_2" + ) + else: + # we have a hf path, replace it with transformer_2 subfolder + subfolder_2 = "transformer_2" + + self.print_and_status_update("Loading transformer 1") + dtype = self.torch_dtype + transformer_1 = WanTransformer3DModel.from_pretrained( + transformer_path_1, + subfolder=subfolder_1, + torch_dtype=dtype, + ).to(dtype=dtype) + + flush() + + if self.model_config.low_vram: + # quantize on the device + transformer_1.to('cpu', dtype=dtype) + flush() + else: + transformer_1.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None: + # todo handle two ARAs + self.print_and_status_update("Quantizing Transformer 1") + quantize_model(self, transformer_1) + flush() + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer 1 to CPU") + transformer_1.to("cpu") + else: + transformer_1.to(self.device_torch) + + self.print_and_status_update("Loading transformer 2") + dtype = self.torch_dtype + transformer_2 = WanTransformer3DModel.from_pretrained( + transformer_path_2, + subfolder=subfolder_2, + torch_dtype=dtype, + ).to(dtype=dtype) + + flush() + + if self.model_config.low_vram: + # quantize on the device + transformer_2.to('cpu', dtype=dtype) + flush() + else: + transformer_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None: + # todo handle two ARAs + self.print_and_status_update("Quantizing Transformer 2") + quantize_model(self, transformer_2) + flush() + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer 2 to CPU") + 
transformer_2.to("cpu") + else: + transformer_2.to(self.device_torch) + + layer_offloading_transformer = self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0 + # make the combined model + self.print_and_status_update("Creating DualWanTransformer3DModel") + transformer = DualWanTransformer3DModel( + transformer_1=transformer_1, + transformer_2=transformer_2, + torch_dtype=self.torch_dtype, + device=self.device_torch, + boundary_ratio=boundary_ratio_t2v, + low_vram=self.model_config.low_vram, + ) + + if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is not None: + # apply the accuracy recovery adapter to both transformers + self.print_and_status_update("Applying Accuracy Recovery Adapter to Transformers") + quantize_model(self, transformer) + flush() + + + if layer_offloading_transformer: + MemoryManager.attach( + transformer_1, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=[transformer_1.scale_shift_table] + [block.scale_shift_table for block in transformer_1.blocks] + ) + MemoryManager.attach( + transformer_2, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=[transformer_2.scale_shift_table] + [block.scale_shift_table for block in transformer_2.blocks] + ) + + return transformer + + def get_generation_pipeline(self): + # todo unipc got broken in a diffusers update. Use euler for now. 
+ # scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + scheduler = self.get_train_scheduler() + pipeline = Wan22Pipeline( + vae=self.vae, + transformer=self.model.transformer_1, + transformer_2=self.model.transformer_2, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + expand_timesteps=self._wan_expand_timesteps, + device=self.device_torch, + aggressive_offload=self.model_config.low_vram, + # todo detect if it is i2v or t2v + boundary_ratio=boundary_ratio_t2v, + ) + + # pipeline = pipeline.to(self.device_torch) + + return pipeline + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def get_base_model_version(self): + return "wan_2.2_14b" + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: DataLoaderBatchDTO, + **kwargs, + ): + # todo do we need to override this? Adjust timesteps? 
+ return super().get_noise_prediction( + latent_model_input=latent_model_input, + timestep=timestep, + text_embeddings=text_embeddings, + batch=batch, + **kwargs, + ) + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + transformer_combo: DualWanTransformer3DModel = unwrap_model(self.model) + transformer_combo.transformer_1.save_pretrained( + save_directory=os.path.join(output_path, "transformer"), + safe_serialization=True, + ) + transformer_combo.transformer_2.save_pretrained( + save_directory=os.path.join(output_path, "transformer_2"), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, "aitk_meta.yaml") + with open(meta_path, "w") as f: + yaml.dump(meta, f) + + def save_lora( + self, + state_dict: Dict[str, torch.Tensor], + output_path: str, + metadata: Optional[Dict[str, Any]] = None, + ): + if not self.network.network_config.split_multistage_loras: + # just save as a combo lora + save_file(state_dict, output_path, metadata=metadata) + return + + # we need to build out both dictionaries for high and low noise LoRAs + high_noise_lora = {} + low_noise_lora = {} + + only_train_high_noise = self.train_high_noise and not self.train_low_noise + only_train_low_noise = self.train_low_noise and not self.train_high_noise + + for key in state_dict: + if ".transformer_1." in key or only_train_high_noise: + # this is a high noise LoRA + new_key = key.replace(".transformer_1.", ".") + high_noise_lora[new_key] = state_dict[key] + elif ".transformer_2." 
in key or only_train_low_noise: + # this is a low noise LoRA + new_key = key.replace(".transformer_2.", ".") + low_noise_lora[new_key] = state_dict[key] + + # loras have either LORA_MODEL_NAME_000005000.safetensors or LORA_MODEL_NAME.safetensors + if len(high_noise_lora.keys()) > 0: + # save the high noise LoRA + high_noise_lora_path = output_path.replace( + ".safetensors", "_high_noise.safetensors" + ) + save_file(high_noise_lora, high_noise_lora_path, metadata=metadata) + + if len(low_noise_lora.keys()) > 0: + # save the low noise LoRA + low_noise_lora_path = output_path.replace( + ".safetensors", "_low_noise.safetensors" + ) + save_file(low_noise_lora, low_noise_lora_path, metadata=metadata) + + def load_lora(self, file: str): + # if it doesnt have high_noise or low_noise, it is a combo LoRA + if ( + "_high_noise.safetensors" not in file + and "_low_noise.safetensors" not in file + ): + # this is a combined LoRA, we dont need to split it up + sd = load_file(file) + return sd + + # we may have been passed the high_noise or the low_noise LoRA path, but we need to load both + high_noise_lora_path = file.replace( + "_low_noise.safetensors", "_high_noise.safetensors" + ) + low_noise_lora_path = file.replace( + "_high_noise.safetensors", "_low_noise.safetensors" + ) + + combined_dict = {} + + if os.path.exists(high_noise_lora_path) and self.train_high_noise: + # load the high noise LoRA + high_noise_lora = load_file(high_noise_lora_path) + for key in high_noise_lora: + new_key = key.replace( + "diffusion_model.", "diffusion_model.transformer_1." + ) + combined_dict[new_key] = high_noise_lora[key] + if os.path.exists(low_noise_lora_path) and self.train_low_noise: + # load the low noise LoRA + low_noise_lora = load_file(low_noise_lora_path) + for key in low_noise_lora: + new_key = key.replace( + "diffusion_model.", "diffusion_model.transformer_2." 
+ ) + combined_dict[new_key] = low_noise_lora[key] + + # if we are not training both stages, we wont have transformer designations in the keys + if not self.train_high_noise or not self.train_low_noise: + new_dict = {} + for key in combined_dict: + if ".transformer_1." in key: + new_key = key.replace(".transformer_1.", ".") + elif ".transformer_2." in key: + new_key = key.replace(".transformer_2.", ".") + else: + new_key = key + new_dict[new_key] = combined_dict[key] + combined_dict = new_dict + + return combined_dict + + def generate_single_image( + self, + pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) + # todo, figure out how to do video + output = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + num_frames=gen_config.num_frames, + generator=generator, + return_dict=False, + output_type="pil", + **extra + )[0] + + # shape = [1, frames, channels, height, width] + batch_item = output[0] # list of pil images + if gen_config.num_frames > 1: + return batch_item # return the frames. + else: + # get just the first image + img = batch_item[0] + return img + + def get_model_to_train(self): + # todo, loras wont load right unless they have the transformer_1 or transformer_2 in the key. + # called when setting up the LoRA. We only need to get the model for the stages we want to train. 
+ if self.train_high_noise and self.train_low_noise: + # we are training both stages, return the unified model + return self.model + elif self.train_high_noise: + # we are only training the high noise stage, return transformer_1 + return self.model.transformer_1 + elif self.train_low_noise: + # we are only training the low noise stage, return transformer_2 + return self.model.transformer_2 + else: + raise ValueError( + "At least one of train_high_noise or train_low_noise must be True in model.model_kwargs" + ) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_5b_model.py b/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_5b_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9d29129bda79584500314eab872102c0819495b5 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_5b_model.py @@ -0,0 +1,292 @@ +from functools import partial +import torch +from toolkit.prompt_utils import PromptEmbeds +from PIL import Image +from diffusers import UniPCMultistepScheduler +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from .wan22_pipeline import Wan22Pipeline + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from torchvision.transforms import functional as TF + +from toolkit.models.wan21.wan21 import Wan21, AggressiveWanUnloadPipeline +from toolkit.models.wan21.wan_utils import add_first_frame_conditioning_v22 + + +# for generation only? 
# Scheduler settings in serialized diffusers-config form; `_class_name` names
# UniPCMultistepScheduler as the target class.
# NOTE(review): Wan225bModel stores this as `_wan_generation_scheduler_config`, but the
# only visible consumer (UniPCMultistepScheduler(**...)) is commented out — confirm
# whether anything still reads it.
scheduler_configUniPC = {
    "_class_name": "UniPCMultistepScheduler",
    "_diffusers_version": "0.35.0.dev0",
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "beta_start": 0.0001,
    "disable_corrector": [],
    "dynamic_thresholding_ratio": 0.995,
    "final_sigmas_type": "zero",
    "flow_shift": 5.0,
    "lower_order_final": True,
    "num_train_timesteps": 1000,
    "predict_x0": True,
    "prediction_type": "flow_prediction",
    "rescale_betas_zero_snr": False,
    "sample_max_value": 1.0,
    "solver_order": 2,
    "solver_p": None,
    "solver_type": "bh2",
    "steps_offset": 0,
    "thresholding": False,
    "time_shift_type": "exponential",
    "timestep_spacing": "linspace",
    "trained_betas": None,
    "use_beta_sigmas": False,
    "use_dynamic_shifting": False,
    "use_exponential_sigmas": False,
    "use_flow_sigmas": True,
    "use_karras_sigmas": False,
}

# Keyword arguments for CustomFlowMatchEulerDiscreteScheduler, as used by
# get_train_scheduler(). The original author marked these values as unverified
# ("I think it is right").
scheduler_config = {
    "num_train_timesteps": 1000,
    "shift": 5.0,
    "use_dynamic_shifting": False,
}

# TODO: this is a temporary monkeypatch to fix the time text embedding to allow for batch sizes greater than 1. Remove this when the diffusers library is fixed.
+def time_text_monkeypatch( + self, + timestep: torch.Tensor, + encoder_hidden_states, + encoder_hidden_states_image = None, + timestep_seq_len = None, +): + timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (encoder_hidden_states.shape[0], timestep_seq_len)) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + +class Wan225bModel(Wan21): + arch = "wan22_5b" + _wan_generation_scheduler_config = scheduler_configUniPC + _wan_expand_timesteps = True + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device=device, + model_config=model_config, + dtype=dtype, + custom_pipeline=custom_pipeline, + noise_scheduler=noise_scheduler, + **kwargs, + ) + + self._wan_cache = None + + def load_model(self): + super().load_model() + + # patch the condition embedder + self.model.condition_embedder.forward = partial(time_text_monkeypatch, self.model.condition_embedder) + + def get_bucket_divisibility(self): + # 16x compression and 2x2 patch size + return 32 + + def get_generation_pipeline(self): + # todo unipc got broken in a diffusers update. Use euler for now. 
+ # scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + scheduler = self.get_train_scheduler() + pipeline = Wan22Pipeline( + vae=self.vae, + transformer=self.model, + transformer_2=self.model, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + expand_timesteps=self._wan_expand_timesteps, + device=self.device_torch, + aggressive_offload=self.model_config.low_vram, + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def get_base_model_version(self): + return "wan_2.2_5b" + + def generate_single_image( + self, + pipeline: AggressiveWanUnloadPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) + + num_frames = ( + (gen_config.num_frames - 1) // 4 + ) * 4 + 1 # make sure it is divisible by 4 + 1 + gen_config.num_frames = num_frames + + height = gen_config.height + width = gen_config.width + noise_mask = None + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img).convert("RGB") + + d = self.get_bucket_divisibility() + + # make sure they are divisible by d + height = height // d * d + width = width // d * d + + # resize the control image + control_img = control_img.resize((width, height), Image.LANCZOS) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = pipeline.prepare_latents( + 1, + num_channels_latents, + height, + width, + gen_config.num_frames, + torch.float32, + self.device_torch, + generator, + None, + ).to(self.torch_dtype) + + first_frame_n1p1 = ( + TF.to_tensor(control_img) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + * 2.0 + - 1.0 + ) # normalize to [-1, 1] + + gen_config.latents, noise_mask = add_first_frame_conditioning_v22( + latent_model_input=latents, first_frame=first_frame_n1p1, vae=self.vae + ) + + output = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype + ), + height=height, + width=width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + num_frames=gen_config.num_frames, + generator=generator, + return_dict=False, + output_type="pil", + noise_mask=noise_mask, + **extra, + )[0] + + # shape = [1, frames, channels, height, width] + batch_item = output[0] # list of pil images + if gen_config.num_frames > 1: + return batch_item # return the frames. + else: + # get just the first image + img = batch_item[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: DataLoaderBatchDTO, + **kwargs, + ): + # videos come in (bs, num_frames, channels, height, width) + # images come in (bs, channels, height, width) + + # for wan, only do i2v for video for now. 
Images do normal t2i + conditioned_latent = latent_model_input + noise_mask = None + + if batch.dataset_config.do_i2v: + with torch.no_grad(): + frames = batch.tensor + if len(frames.shape) == 4: + first_frames = frames + elif len(frames.shape) == 5: + first_frames = frames[:, 0] + # Add conditioning using the standalone function + conditioned_latent, noise_mask = add_first_frame_conditioning_v22( + latent_model_input=latent_model_input.to( + self.device_torch, self.torch_dtype + ), + first_frame=first_frames.to(self.device_torch, self.torch_dtype), + vae=self.vae, + ) + else: + raise ValueError(f"Unknown frame shape {frames.shape}") + + # make the noise mask + if noise_mask is None: + noise_mask = torch.ones( + conditioned_latent.shape, + dtype=conditioned_latent.dtype, + device=conditioned_latent.device, + ) + # todo write this better + t_chunks = torch.chunk(timestep, timestep.shape[0]) + out_t_chunks = [] + for t in t_chunks: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + temp_ts = temp_ts.unsqueeze(0) + out_t_chunks.append(temp_ts) + timestep = torch.cat(out_t_chunks, dim=0) + + noise_pred = self.model( + hidden_states=conditioned_latent, + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds, + return_dict=False, + **kwargs, + )[0] + return noise_pred diff --git a/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py b/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d119343c9cc03235f5a2bb4593e843a2a6fbb939 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py @@ -0,0 +1,336 @@ + +import torch +from toolkit.basic import flush +from transformers import AutoTokenizer, UMT5EncoderModel +from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan +import torch +from diffusers import 
FlowMatchEulerDiscreteScheduler +from typing import List +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.image_processor import PipelineImageInput + + +class Wan22Pipeline(WanPipeline): + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + transformer_2: Optional[WanTransformer3DModel] = None, + boundary_ratio: Optional[float] = None, + expand_timesteps: bool = False, # Wan2.2 ti2v + device: torch.device = torch.device("cuda"), + aggressive_offload: bool = False, + ): + super().__init__( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + transformer_2=transformer_2, + boundary_ratio=boundary_ratio, + expand_timesteps=expand_timesteps, + vae=vae, + scheduler=scheduler, + ) + self._aggressive_offload = aggressive_offload + self._exec_device = device + @property + def _execution_device(self): + return self._exec_device + + def __call__( + self: WanPipeline, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], 
None], + PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + noise_mask: Optional[torch.Tensor] = None, + ): + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if num_frames % self.vae_scale_factor_temporal != 1: + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + + width = width // (self.vae.config.scale_factor_spatial * 2) * (self.vae.config.scale_factor_spatial * 2) + height = height // (self.vae.config.scale_factor_spatial * 2) * (self.vae.config.scale_factor_spatial * 2) + + # unload vae and transformer + vae_device = self.vae.device + transformer_device = self.transformer.device + text_encoder_device = self.text_encoder.device + device = self._exec_device + + if self._aggressive_offload: + print("Unloading vae") + self.vae.to("cpu") + print("Unloading transformer") + self.transformer.to("cpu") + if self.transformer_2 is not None: + self.transformer_2.to("cpu") + self.text_encoder.to(device) + flush() + + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2 + ) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if self._aggressive_offload: + # unload text encoder + print("Unloading text encoder") + self.text_encoder.to("cpu") + self.transformer.to(device) + flush() + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(device, transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + device, transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + + conditioning = None # wan2.2 i2v conditioning + # check shape of latents to see if it is first frame conditioned for 2.2 14b i2v + if latents is not None: + if latents.shape[1] == 36: + # first 16 channels are latent. other 20 are conditioning + conditioning = latents[:, 16:] + latents = latents[:, :16] + + # we need to trick the in_channls to think it is only 16 channels + num_channels_latents = 16 + + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + mask = noise_mask + if mask is None: + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - \ + num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + current_model = self.transformer + + if self._aggressive_offload: + # we don't have one loaded yet in aggressive offload mode + current_model = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + if self._aggressive_offload and current_model != self.transformer: + if self.transformer_2 is not None: + self.transformer_2.to("cpu") + self.transformer.to(device) + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + if self._aggressive_offload and current_model != self.transformer_2: + if self.transformer is not None: + self.transformer.to("cpu") + if self.transformer_2 is not None: + self.transformer_2.to(device) + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + latent_model_input = latents.to(device, transformer_dtype) + if self.config.expand_timesteps: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + pre_condition_latent_model_input = latent_model_input.clone() + + if conditioning is not None: + # conditioning is first frame conditioning for 2.2 i2v + latent_model_input = torch.cat( + [latent_model_input, conditioning], dim=1) + + noise_pred = current_model( + hidden_states=latent_model_input, + timestep=timestep, + 
encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + current_guidance_scale * \ + (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False)[0] + + # apply i2v mask + latents = (pre_condition_latent_model_input * (1 - mask)) + ( + latents * mask + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if self._aggressive_offload: + # unload transformer + print("Unloading transformer") + self.transformer.to("cpu") + if self.transformer_2 is not None: + self.transformer_2.to("cpu") + # load vae + print("Loading Vae") + self.vae.to(vae_device) + flush() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = 
latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video( + video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + # move transformer back to device + if self._aggressive_offload: + # print("Moving transformer back to device") + # self.transformer.to(self._execution_device) + flush() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/z_image/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/z_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93b76974ef7b79322c8ab7271439e30c91d3552b --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/z_image/__init__.py @@ -0,0 +1,2 @@ +from .z_image import ZImageModel +from .z_image_l2p_model import ZImageL2PModel \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/z_image/z_image.py b/ai-toolkit/extensions_built_in/diffusion_models/z_image/z_image.py new file mode 100644 index 0000000000000000000000000000000000000000..a9944d877c336009b307337516d40d7c95794dca --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/z_image/z_image.py @@ -0,0 +1,402 @@ +import os +from typing import List, Optional + +import huggingface_hub +import torch +import yaml +from toolkit.config_modules import GenerateImageConfig, ModelConfig, NetworkConfig +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype, quantize_model +from 
toolkit.memory_management import MemoryManager +from safetensors.torch import load_file + +from transformers import AutoTokenizer, Qwen3ForCausalLM +from diffusers import AutoencoderKL + +try: + from diffusers import ZImagePipeline + from diffusers.models.transformers import ZImageTransformer2DModel +except ImportError: + raise ImportError( + "Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt" + ) + + +scheduler_config = { + "num_train_timesteps": 1000, + "use_dynamic_shifting": False, + "shift": 3.0, +} + + +class ZImageModel(BaseModel): + arch = "zimage" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["ZImageTransformer2DModel"] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return 8 * 2 # 8 for the VAE, 2 for patch size + + def load_training_adapter(self, transformer: ZImageTransformer2DModel): + self.print_and_status_update("Loading assistant LoRA") + lora_path = self.model_config.assistant_lora_path + if not os.path.exists(lora_path): + # assume it is a hub path + lora_splits = lora_path.split("/") + if len(lora_splits) != 3: + raise ValueError( + f"Assistant LoRA path {lora_path} is not a valid local path or hub path." 
+ ) + repo_id = "/".join(lora_splits[:2]) + filename = lora_splits[2] + try: + lora_path = huggingface_hub.hf_hub_download( + repo_id=repo_id, + filename=filename, + ) + # upgrade path to + self.model_config.assistant_lora_path = lora_path + except Exception as e: + raise ValueError( + f"Failed to download assistant LoRA from {lora_path}: {e}" + ) + # load the adapter and merge it in. We will inference with a -1.0 multiplier so the adapter effects only work during training. + lora_state_dict = load_file(lora_path) + dim = int( + lora_state_dict[ + "diffusion_model.layers.0.attention.to_k.lora_A.weight" + ].shape[0] + ) + + new_sd = {} + for key, value in lora_state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + lora_state_dict = new_sd + + network_config = { + "type": "lora", + "linear": dim, + "linear_alpha": dim, + "transformer_only": True, + } + + network_config = NetworkConfig(**network_config) + LoRASpecialNetwork.LORA_PREFIX_UNET = "lora_transformer" + network = LoRASpecialNetwork( + text_encoder=None, + unet=transformer, + lora_dim=network_config.linear, + multiplier=1.0, + alpha=network_config.linear_alpha, + train_unet=True, + train_text_encoder=False, + network_config=network_config, + network_type=network_config.type, + transformer_only=network_config.transformer_only, + is_transformer=True, + target_lin_modules=self.target_lora_modules, + is_assistant_adapter=True, + is_ara=True, + ) + network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True) + self.print_and_status_update("Merging in assistant LoRA") + network.force_to(self.device_torch, dtype=self.torch_dtype) + network._update_torch_multiplier() + network.load_weights(lora_state_dict) + + network.merge_in(merge_weight=1.0) + + # mark it as not merged so inference ignores it. 
+ network.is_merged_in = False + + # add the assistant so sampler will activate it while sampling + self.assistant_lora: LoRASpecialNetwork = network + + # deactivate lora during training + self.assistant_lora.multiplier = -1.0 + self.assistant_lora.is_active = False + + # tell the model to invert assistant on inference since we want remove lora effects + self.invert_assistant_lora = True + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading ZImage model") + model_path = self.model_config.name_or_path + base_model_path = self.model_config.extras_name_or_path + + self.print_and_status_update("Loading transformer") + + transformer_path = model_path + transformer_subfolder = "transformer" + if os.path.exists(transformer_path): + transformer_subfolder = None + transformer_path = os.path.join(transformer_path, "transformer") + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, "text_encoder") + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = ZImageTransformer2DModel.from_pretrained( + transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype + ) + + # load assistant lora if specified + if self.model_config.assistant_lora_path is not None: + self.load_training_adapter(transformer) + # set qtype to be float8 if it is qfloat8 + if self.model_config.qtype == "qfloat8": + self.model_config.qtype = "float8" + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=[ + transformer.x_pad_token, + transformer.cap_pad_token, + ] + ) + + if 
self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + + flush() + + self.print_and_status_update("Text Encoder") + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype + ) + text_encoder = Qwen3ForCausalLM.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype + ) + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Text Encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype + ) + + self.noise_scheduler = ZImageModel.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + kwargs = {} + + pipe: ZImagePipeline = ZImagePipeline( + scheduler=self.noise_scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + **kwargs, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] + + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = 
text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = ZImageModel.get_train_scheduler() + + pipeline: ZImagePipeline = ZImagePipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: ZImagePipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra, + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + timestep_model_input = (1000 - timestep) / 1000 + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + text_embeddings.text_embeds, + )[0] + + noise_pred = torch.stack([t.float() for t 
in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + prompt_embeds, _ = self.pipeline.encode_prompt( + prompt, + do_classifier_free_guidance=False, + device=self.device_torch, + ) + pe = PromptEmbeds([prompt_embeds, None]) + return pe + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + transformer: ZImageTransformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, "transformer"), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, "aitk_meta.yaml") + with open(meta_path, "w") as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get("noise") + batch = kwargs.get("batch") + return (noise - batch.latents).detach() + + def get_base_model_version(self): + return "zimage" + + def get_transformer_block_names(self) -> Optional[List[str]]: + return ["layers"] + + def convert_lora_weights_before_save(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd diff --git a/ai-toolkit/extensions_built_in/diffusion_models/z_image/z_image_l2p_model.py b/ai-toolkit/extensions_built_in/diffusion_models/z_image/z_image_l2p_model.py new file mode 100644 index 0000000000000000000000000000000000000000..34fd183d1fb55b821f2250060126a18f9c1dd53c --- /dev/null +++ 
b/ai-toolkit/extensions_built_in/diffusion_models/z_image/z_image_l2p_model.py @@ -0,0 +1,559 @@ +from extensions_built_in.diffusion_models.z_image.z_image import ZImageModel +import os +from typing import Dict, List, Optional, Union + +import huggingface_hub +import torch +import torch.nn as nn +import torch.nn.functional as F +import yaml +from toolkit.basic import flush +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.memory_management import MemoryManager + +from transformers import AutoTokenizer, Qwen3ForCausalLM +from toolkit.models.FakeVAE import FakeVAE +from toolkit.paths import MODELS_PATH +from safetensors.torch import load_file, save_file +from toolkit.metadata import get_meta_for_safetensors + +HF_TOKEN = os.getenv("HF_TOKEN", None) + +try: + from diffusers import ZImagePipeline + from diffusers.models.transformers.transformer_z_image import ( + ZImageTransformer2DModel as ZImageTransformer2DModelOriginal, + ) + from diffusers.models.modeling_outputs import Transformer2DModelOutput +except ImportError: + raise ImportError( + "Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt" + ) + + +# Default ZImage transformer config used when loading from a single safetensors +# file (no config.json available alongside). 
+ZIMAGE_DEFAULT_CONFIG = { + "_class_name": "ZImageTransformer2DModel", + "_diffusers_version": "0.37.0.dev0", + "all_f_patch_size": [1], + "all_patch_size": [2], + "axes_dims": [32, 48, 48], + "axes_lens": [1536, 512, 512], + "cap_feat_dim": 2560, + "dim": 3840, + "in_channels": 16, + "n_heads": 30, + "n_kv_heads": 30, + "n_layers": 30, + "n_refiner_layers": 2, + "norm_eps": 1e-05, + "qk_norm": True, + "rope_theta": 256.0, + "siglip_feat_dim": None, + "t_scale": 1000.0, +} + + +class MicroDiffusionModel(nn.Module): + """L2P pixel-space decoder: a small 4-stage U-Net that fuses the transformer + feature map at the bottleneck and outputs pixel-space prediction.""" + + def __init__(self, in_channels: int, si_t_hidden_size: int): + super().__init__() + + self.enc1 = nn.Sequential( + nn.Conv2d(in_channels, 64, kernel_size=3, padding=1), nn.SiLU() + ) + self.pool1 = nn.MaxPool2d(2, stride=2) + self.enc2 = nn.Sequential( + nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.SiLU() + ) + self.pool2 = nn.MaxPool2d(2, stride=2) + self.enc3 = nn.Sequential( + nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.SiLU() + ) + self.pool3 = nn.MaxPool2d(2, stride=2) + self.enc4 = nn.Sequential( + nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.SiLU() + ) + self.pool4 = nn.MaxPool2d(2, stride=2) + + self.bottleneck = nn.Sequential( + nn.Conv2d(512 + si_t_hidden_size, 512, kernel_size=1), + nn.SiLU(), + ) + + self.up4 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(512, 512, kernel_size=3, padding=1), + ) + self.dec4 = nn.Sequential( + nn.Conv2d(512 + 512, 256, kernel_size=3, padding=1), nn.SiLU() + ) + self.up3 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + ) + self.dec3 = nn.Sequential( + nn.Conv2d(256 + 256, 128, kernel_size=3, padding=1), nn.SiLU() + ) + self.up2 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(128, 128, kernel_size=3, padding=1), + ) + 
self.dec2 = nn.Sequential( + nn.Conv2d(128 + 128, 64, kernel_size=3, padding=1), nn.SiLU() + ) + self.up1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(64, 64, kernel_size=3, padding=1), + ) + self.dec1 = nn.Sequential( + nn.Conv2d(64 + 64, 64, kernel_size=3, padding=1), nn.SiLU() + ) + self.out_conv = nn.Conv2d(64, in_channels, kernel_size=1) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + e1 = self.enc1(x) + p1 = self.pool1(e1) + e2 = self.enc2(p1) + p2 = self.pool2(e2) + e3 = self.enc3(p2) + p3 = self.pool3(e3) + e4 = self.enc4(p3) + p4 = self.pool4(e4) + + if c.shape[-2:] != p4.shape[-2:]: + c = F.interpolate(c, size=p4.shape[-2:], mode="nearest") + b = self.bottleneck(torch.cat([p4, c.to(p4.dtype)], dim=1)) + + d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1)) + d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1)) + d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1)) + d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1)) + + return self.out_conv(d1) + + +class ZImageTransformer2DModel(ZImageTransformer2DModelOriginal): + """L2P-style ZImage transformer: runs the standard trunk but replaces the + FinalLayer + unpatchify tail with a pixel-space U-Net decoder.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # FinalLayer is unused in L2P — the pixel-space U-Net does the decoding. + # Removing it keeps the module out of state_dict/named_parameters. 
+ if hasattr(self, "all_final_layer"): + del self.all_final_layer + self.local_decoder = MicroDiffusionModel( + in_channels=self.in_channels, + si_t_hidden_size=self.dim, + ) + + def forward( + self, + x: Union[List[torch.Tensor], List[List[torch.Tensor]]], + t, + cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]], + return_dict: bool = True, + controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None, + siglip_feats: Optional[List[List[torch.Tensor]]] = None, + image_noise_mask: Optional[List[List[int]]] = None, + patch_size: int = 16, + f_patch_size: int = 1, + ): + assert ( + patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size + ) + assert not isinstance(x[0], list), "L2P does not support omni mode" + device = x[0].device + + # Capture original noisy pixel images for the U-Net decoder. + noisy_images = torch.stack(x, dim=0) + if noisy_images.dim() == 5: + noisy_images = noisy_images.squeeze(2) + bsz, _, H_ori, W_ori = noisy_images.shape + + adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0]) + + ( + x_patches, + cap_feats_proc, + x_size, + x_pos_ids, + cap_pos_ids, + x_pad_mask, + cap_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # X embed & refine + x_seqlens = [len(xi) for xi in x_patches] + x_embed = self.all_x_embedder[f"{patch_size}-{f_patch_size}"]( + torch.cat(x_patches, dim=0) + ) + x_embed, x_freqs, x_mask, _, _ = self._prepare_sequence( + list(x_embed.split(x_seqlens, dim=0)), + x_pos_ids, + x_pad_mask, + self.x_pad_token, + None, + device, + ) + for layer in self.noise_refiner: + x_embed = ( + self._gradient_checkpointing_func( + layer, x_embed, x_mask, x_freqs, adaln_input, None, None, None + ) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(x_embed, x_mask, x_freqs, adaln_input, None, None, None) + ) + + # Cap embed & refine + cap_seqlens = [len(ci) for ci in cap_feats_proc] + cap_embed = self.cap_embedder(torch.cat(cap_feats_proc, 
dim=0)) + cap_embed, cap_freqs, cap_mask, _, _ = self._prepare_sequence( + list(cap_embed.split(cap_seqlens, dim=0)), + cap_pos_ids, + cap_pad_mask, + self.cap_pad_token, + None, + device, + ) + for layer in self.context_refiner: + cap_embed = ( + self._gradient_checkpointing_func(layer, cap_embed, cap_mask, cap_freqs) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(cap_embed, cap_mask, cap_freqs) + ) + + # Unified sequence: basic mode = [x, cap] + unified, unified_freqs, unified_mask, _ = self._build_unified_sequence( + x_embed, + x_freqs, + x_seqlens, + None, + cap_embed, + cap_freqs, + cap_seqlens, + None, + None, + None, + None, + None, + False, + device, + ) + + for layer_idx, layer in enumerate(self.layers): + unified = ( + self._gradient_checkpointing_func( + layer, + unified, + unified_mask, + unified_freqs, + adaln_input, + None, + None, + None, + ) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer( + unified, unified_mask, unified_freqs, adaln_input, None, None, None + ) + ) + if ( + controlnet_block_samples is not None + and layer_idx in controlnet_block_samples + ): + unified = unified + controlnet_block_samples[layer_idx] + + # L2P tail: extract image tokens, reshape to (B, dim, H/p, W/p), decode in pixel space. 
+ feat_H = H_ori // patch_size + feat_W = W_ori // patch_size + img_token_len = feat_H * feat_W + img_features = unified[:, :img_token_len, :] + feat_map = ( + img_features.reshape(bsz, feat_H, feat_W, self.dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + decoded = self.local_decoder(noisy_images, feat_map) + decoded = decoded.unsqueeze(2) # add F=1 axis to match (C, F, H, W) downstream + x_out = list(decoded.unbind(0)) + + return (x_out,) if not return_dict else Transformer2DModelOutput(sample=x_out) + + +class ZImageL2PModel(ZImageModel): + arch = "zimage_l2p" + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading ZImage model") + model_path = self.model_config.name_or_path + base_model_path = self.model_config.extras_name_or_path + + # If model_path looks like "//.../file.safetensors" + # and isn't a local file or directory, resolve it from HF Hub. Cache the + # downloaded file under MODELS_PATH/diffusion_models and reuse it on + # subsequent runs. + if ( + not os.path.isfile(model_path) + and not os.path.isdir(model_path) + and model_path.endswith(".safetensors") + and model_path.count("/") >= 2 + ): + repo_id, filename = model_path.rsplit("/", 1) + target_dir = os.path.join(MODELS_PATH, "diffusion_models") + target_path = os.path.join(target_dir, filename) + if os.path.isfile(target_path): + self.print_and_status_update(f"Using cached weights at {target_path}") + model_path = target_path + else: + os.makedirs(target_dir, exist_ok=True) + self.print_and_status_update(f"Downloading {filename} from {repo_id}") + model_path = huggingface_hub.hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=target_dir, + token=HF_TOKEN, + ) + + self.print_and_status_update("Loading transformer") + + transformer_path = model_path + transformer_subfolder = "transformer" + + if os.path.isfile(model_path): + # Local single-file checkpoint (e.g. a .safetensors merge). 
+ # No sidecar config.json — build the architecture from the hardcoded + # ZImage default config (or a pixel-space variant if the checkpoint + # already contains L2P / pixel-space keys) and overwrite weights. + # Pull the rest of the pipeline (tokenizer, text encoder) from the + # official Z-Image-Turbo repo since we have no local sidecar config. + self.print_and_status_update( + f"Loading transformer weights from {model_path}" + ) + sd = load_file(model_path) + sd = {k: v.to(dtype) for k, v in sd.items()} + + # Detect a pixel-space (L2P) checkpoint by the presence of either + # the pixel-size x_embedder or local_decoder keys. Also infer + # in_channels from the local_decoder shape so we match whatever + # the checkpoint actually was trained with. + has_l2p_keys = any(k.startswith("local_decoder.") for k in sd) + has_pixel_xemb = "all_x_embedder.16-1.weight" in sd + is_pixel = has_l2p_keys or has_pixel_xemb + + inferred_in_channels = None + if "local_decoder.enc1.0.weight" in sd: + inferred_in_channels = sd["local_decoder.enc1.0.weight"].shape[1] + elif has_pixel_xemb: + inferred_in_channels = 3 + + config = dict(ZIMAGE_DEFAULT_CONFIG) + if is_pixel: + config["in_channels"] = ( + inferred_in_channels if inferred_in_channels else 3 + ) + config["all_patch_size"] = [16] + self.print_and_status_update( + f" detected pixel-space checkpoint (L2P), in_channels={config['in_channels']}" + ) + + # Strip ConfigMixin metadata before passing to the constructor. + init_args = {k: v for k, v in config.items() if not k.startswith("_")} + transformer = ZImageTransformer2DModel(**init_args) + transformer = transformer.to(dtype) + self.print_and_status_update( + f" built transformer: in_channels={transformer.in_channels}, " + f"all_patch_size={transformer.all_patch_size}" + ) + + missing, unexpected = transformer.load_state_dict(sd, strict=False) + if unexpected: + self.print_and_status_update( + f" {len(unexpected)} unexpected keys (e.g. 
{unexpected[:3]})" + ) + if missing: + self.print_and_status_update( + f" {len(missing)} missing keys kept at init (e.g. {missing[:3]})" + ) + del sd + flush() + else: + if os.path.exists(transformer_path): + transformer_subfolder = None + transformer_path = os.path.join(transformer_path, "transformer") + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, "text_encoder") + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = ZImageTransformer2DModel.from_pretrained( + transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype + ) + # convert it to pixel space if needed + if transformer.config.in_channels != 3: + self.print_and_status_update("Converting transformer to pixel space") + old_transformer = transformer + new_config = dict(transformer.config) + new_config["in_channels"] = 3 + new_config["all_patch_size"] = [16] + transformer = ZImageTransformer2DModel.from_config(new_config) + + # update the state dict + old_transformer_state = old_transformer.state_dict() + new_transformer_state_dict = {} + for k, v in old_transformer_state.items(): + if k == "all_x_embedder.2-1.weight": + new_v = torch.randn( + v.shape[0], + 768, + dtype=v.dtype, + device=v.device, + ) + new_transformer_state_dict["all_x_embedder.16-1.weight"] = ( + new_v * 0.001 + ) + elif k.startswith("all_final_layer."): + # FinalLayer is unused in L2P; pixel decoding is done by local_decoder. + continue + else: + new_transformer_state_dict[k] = v + + # local_decoder.* keys are absent from the source checkpoint; they keep + # the random init from MicroDiffusionModel.__init__ via strict=False. 
+ transformer.load_state_dict(new_transformer_state_dict, strict=False) + del old_transformer + del old_transformer_state + del new_transformer_state_dict + flush() + + # load assistant lora if specified + if self.model_config.assistant_lora_path is not None: + self.load_training_adapter(transformer) + # set qtype to be float8 if it is qfloat8 + if self.model_config.qtype == "qfloat8": + self.model_config.qtype = "float8" + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=[ + transformer.x_pad_token, + transformer.cap_pad_token, + ], + ) + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + + flush() + + self.print_and_status_update("Text Encoder") + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype + ) + text_encoder = Qwen3ForCausalLM.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype + ) + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Text Encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + + self.print_and_status_update("Loading VAE") + # vae = AutoencoderKL.from_pretrained( + # base_model_path, subfolder="vae", torch_dtype=dtype + # ) + vae = 
FakeVAE(scaling_factor=1.0) + + self.noise_scheduler = ZImageModel.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + kwargs = {} + + pipe: ZImagePipeline = ZImagePipeline( + scheduler=self.noise_scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + **kwargs, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] + + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def save_model(self, output_path, meta, save_dtype): + transformer: ZImageTransformer2DModel = unwrap_model(self.model) + if not output_path.endswith(".safetensors"): + output_path += ".safetensors" + meta = get_meta_for_safetensors(meta, name=self.arch) + + sd = transformer.state_dict() + save_dict = {} + for key, value in sd.items(): + # Skip the unused FinalLayer — L2P bypasses it in forward(), so + # the weights are dead and just bloat the checkpoint. 
+ if key.startswith("all_final_layer."): + continue + save_dict[key] = value.to("cpu").to(save_dtype) + save_file(save_dict, output_path, metadata=meta) diff --git a/ai-toolkit/extensions_built_in/diffusion_models/zeta_chroma/__init__.py b/ai-toolkit/extensions_built_in/diffusion_models/zeta_chroma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b906770910a152dc5c3bc7fd38ea190b2218f93e --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/zeta_chroma/__init__.py @@ -0,0 +1 @@ +from .zeta_chroma_model import ZetaChromaModel \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_model.py b/ai-toolkit/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ca8b491fac822f147415cedd8af20fe83cc52c40 --- /dev/null +++ b/ai-toolkit/extensions_built_in/diffusion_models/zeta_chroma/zeta_chroma_model.py @@ -0,0 +1,383 @@ +import os +from typing import List, Optional + +import huggingface_hub +import torch +import yaml +from toolkit.config_modules import GenerateImageConfig, ModelConfig, NetworkConfig +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze +from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.memory_management import MemoryManager +from safetensors.torch import load_file +from optimum.quanto import QTensor +from toolkit.metadata import get_meta_for_safetensors +from safetensors.torch import load_file, save_file +from transformers import AutoTokenizer, Qwen3ForCausalLM +from diffusers import AutoencoderKL +from toolkit.models.FakeVAE import 
FakeVAE +from .zeta_chroma_transformer import ZImageDCT, ZImageDCTParams, vae_flatten, vae_unflatten, prepare_latent_image_ids, make_text_position_ids +from .zeta_chroma_pipeline import ZetaChromaPipeline + + + +scheduler_config = { + "num_train_timesteps": 1000, + "use_dynamic_shifting": False, + "shift": 3.0, +} + +ZETA_CHROMA_TRANSFORMER_FILENAME = "zeta-chroma-base-x0-pixel-dino-distance.safetensors" + + +class ZetaChromaModel(BaseModel): + arch = "zeta_chroma" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["ZImageDCT"] + self.patch_size = 32 + self.max_sequence_length = 512 + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return self.patch_size + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading ZImage model") + model_path = self.model_config.name_or_path + base_model_path = self.model_config.extras_name_or_path + + self.print_and_status_update("Loading transformer") + + + transformer_path = model_path + if not os.path.exists(transformer_path): + transformer_name = ZETA_CHROMA_TRANSFORMER_FILENAME + # if path ends with .safetensors, assume last part is filename + # this allows users to target different file names in the repo like + # lodestones/Zeta-Chroma/zeta-chroma-base-x0-pixel-dino-distance.safetensors + + if transformer_path.endswith(".safetensors"): + splits = transformer_path.split("/") + transformer_name = splits[-1] + transformer_path = "/".join(splits[:-1]) + # assume it is from the hub + transformer_path = huggingface_hub.hf_hub_download( + repo_id=transformer_path, + 
filename=transformer_name, + ) + + transformer_state_dict = load_file(transformer_path, device="cpu") + + # cast to dtype + for key in transformer_state_dict: + transformer_state_dict[key] = transformer_state_dict[key].to(dtype) + + # Auto-detect use_x0 from checkpoint + use_x0 = "__x0__" in transformer_state_dict + + # Build model params + in_channels = self.patch_size * self.patch_size * 3 # RGB patches + model_params = ZImageDCTParams( + patch_size=1, + in_channels=in_channels, + use_x0=use_x0, + ) + + with torch.device("meta"): + transformer = ZImageDCT(model_params) + + transformer.load_state_dict(transformer_state_dict, assign=True) + del transformer_state_dict + + transformer.to(self.quantize_device, dtype=dtype) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_transformer_percent > 0 + ): + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent, + ignore_modules=[ + transformer.x_pad_token, + transformer.cap_pad_token, + ], + ) + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to("cpu") + + flush() + + self.print_and_status_update("Text Encoder") + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype + ) + text_encoder = Qwen3ForCausalLM.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype + ) + + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + 
self.print_and_status_update("Quantizing Text Encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + + self.print_and_status_update("Loading VAE") + vae = FakeVAE(scaling_factor=1.0) + vae.to(self.device_torch, dtype=dtype) + + self.noise_scheduler = ZetaChromaModel.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + kwargs = {} + + pipe: ZetaChromaPipeline = ZetaChromaPipeline( + scheduler=self.noise_scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + **kwargs, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] + + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = ZetaChromaModel.get_train_scheduler() + + pipeline: ZetaChromaPipeline = ZetaChromaPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: ZetaChromaPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + 
generator: torch.Generator, + extra: dict, + ): + self.model.to(self.device_torch, dtype=self.torch_dtype) + self.model.to(self.device_torch) + + do_low_step_schedule = gen_config.num_inference_steps <= 8 and gen_config.guidance_scale <= 1.0 + + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_embeds_mask=conditional_embeds.attention_mask, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_embeds_mask=unconditional_embeds.attention_mask, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + low_step_schedule=do_low_step_schedule, + **extra, + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs, + ): + if self.model.device == torch.device("cpu"): + self.model.to(self.device_torch) + + with torch.no_grad(): + + pixel_shape = latent_model_input.shape + # todo: do we invert like this? 
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
    """Encode prompt text with the pipeline's text encoder.

    Args:
        prompt: A single prompt string, or a list of prompt strings.

    Returns:
        A `PromptEmbeds` built from the encoder output, carrying the
        attention mask returned by the pipeline.
    """
    # Make sure the text encoder sits on the training device before encoding.
    if self.pipeline.text_encoder.device != self.device_torch:
        self.pipeline.text_encoder.to(self.device_torch)

    # The pipeline encoder iterates over its argument, so a bare string would
    # be split into one "prompt" per character. Wrap it in a list first.
    prompts = [prompt] if isinstance(prompt, str) else prompt

    prompt_embeds, mask = self.pipeline._encode_prompts(prompts)
    pe = PromptEmbeds([prompt_embeds, None], attention_mask=mask)

    return pe
def convert_lora_weights_before_save(self, state_dict):
    """Rename LoRA keys for export.

    Every occurrence of "transformer." in a key is replaced with
    "diffusion_model."; values are passed through untouched. A new dict is
    returned and `state_dict` itself is not modified.
    """
    return {
        name.replace("transformer.", "diffusion_model."): tensor
        for name, tensor in state_dict.items()
    }
class ZetaChromaPipeline(ZImagePipeline):
    """Sampling pipeline for Zeta Chroma built on diffusers' `ZImagePipeline`.

    `__call__` works from pre-computed prompt embeddings: it draws noise in
    patchified form, integrates the transformer prediction with Euler steps
    over a timestep schedule, and unpatchifies the result into an image.
    """

    need_something_here = True
    patch_size = 32
    max_sequence_length = 512

    @torch.no_grad()
    def _encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a list of prompts with the Qwen3 chat template.

        Args:
            prompts: Prompt strings. A single string is treated as a
                one-element list instead of being iterated per character.

        Returns:
            `(embeddings, mask)`: the second-to-last hidden state of the text
            encoder and the tokenizer attention mask converted to bool.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        formatted = []
        for p in prompts:
            messages = [{"role": "user", "content": p}]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            formatted.append(text)

        # Always pad/truncate to the fixed sequence length.
        inputs = self.tokenizer(
            formatted,
            padding="max_length",
            max_length=self.max_sequence_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.text_encoder.device)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = self.text_encoder(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_hidden_states=True,
            )

        # Second-to-last hidden state (same as training)
        embeddings = outputs.hidden_states[-2]
        mask = inputs.attention_mask.bool()
        return embeddings, mask

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        cfg_normalization: bool = False,
        cfg_truncation: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        prompt_embeds_mask: Optional[torch.BoolTensor] = None,
        negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        negative_prompt_embeds_mask: Optional[torch.BoolTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        low_step_schedule: bool = False,
    ):
        """Generate images from pre-computed prompt embeddings.

        Only embeddings are supported as conditioning. The parameters
        `prompt`, `negative_prompt`, `sigmas`, `cfg_normalization`,
        `cfg_truncation`, `num_images_per_prompt`, `latents`,
        `joint_attention_kwargs`, `callback_on_step_end` and
        `callback_on_step_end_tensor_inputs` are accepted for signature
        compatibility but are not read by this method.

        Args:
            height, width: Output size in pixels (required); floor-divided by
                `patch_size` to get the patch grid.
            num_inference_steps: Number of Euler steps.
            guidance_scale: Classifier-free guidance weight; guidance is only
                applied when it is greater than 1.0.
            generator: RNG used for the initial noise.
            prompt_embeds, prompt_embeds_mask: Positive conditioning (required).
            negative_prompt_embeds, negative_prompt_embeds_mask: Negative
                conditioning; required only when guidance is applied.
            output_type: "latent" returns the patchified tensor as-is,
                anything else is passed to the image processor.
            return_dict: Return a `ZImagePipelineOutput` instead of a tuple.
            max_sequence_length: Length of the generated text position ids
                (presumably must equal the token length of the embeddings;
                confirm against the transformer).
            low_step_schedule: Use the uniformly spaced schedule instead of
                the shifted one.

        Raises:
            ValueError: If required embeddings or the image size are missing.
        """
        if prompt_embeds is None or prompt_embeds_mask is None:
            raise ValueError(
                "prompt_embeds and prompt_embeds_mask are required; "
                "raw text prompts are not supported by this pipeline."
            )
        if height is None or width is None:
            raise ValueError("height and width must be provided.")

        use_cfg = guidance_scale > 1.0
        if use_cfg and (
            negative_prompt_embeds is None or negative_prompt_embeds_mask is None
        ):
            raise ValueError(
                "negative_prompt_embeds and negative_prompt_embeds_mask are "
                "required when guidance_scale > 1.0."
            )

        device = self._execution_device
        batch_size = len(prompt_embeds)
        patch_size = self.patch_size
        in_channels = patch_size * patch_size * 3

        h_patches = height // patch_size
        w_patches = width // patch_size
        num_patches = h_patches * w_patches

        pos_embeds, pos_mask = prompt_embeds, prompt_embeds_mask
        neg_embeds, neg_mask = negative_prompt_embeds, negative_prompt_embeds_mask

        # --- Build position IDs ---
        pos_lengths = pos_mask.sum(1)
        if neg_mask is not None:
            # Image ids start after the longer of the two prompts so both
            # passes share the same image positions.
            neg_lengths = neg_mask.sum(1)
            offset = torch.maximum(pos_lengths, neg_lengths)
            neg_text_ids = make_text_position_ids(
                neg_lengths, max_sequence_length
            ).to(device)
        else:
            offset = pos_lengths
            neg_text_ids = None

        image_pos_ids = prepare_latent_image_ids(
            offset, h_patches, w_patches, patch_size=1
        ).to(device)
        pos_text_ids = make_text_position_ids(pos_lengths, max_sequence_length).to(
            device
        )

        # --- Initial noise ---
        noise = randn_tensor(
            (batch_size, num_patches, in_channels),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )

        # --- Timestep schedule ---
        if low_step_schedule:
            timesteps = get_low_step_schedule(num_inference_steps)
        else:
            timesteps = get_schedule(num_inference_steps, num_patches)

        img = noise
        img_mask = torch.ones(
            (batch_size, num_patches), device=device, dtype=torch.bool
        )
        self._num_timesteps = len(timesteps)

        # --- Denoising loop (CFG) ---
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):

                t_vec = torch.full(
                    (batch_size,), t_curr, dtype=self.dtype, device=device
                )

                pred = self.transformer(
                    img=img,
                    img_ids=image_pos_ids,
                    img_mask=img_mask,
                    txt=pos_embeds,
                    txt_ids=pos_text_ids,
                    txt_mask=pos_mask,
                    timesteps=t_vec,
                )

                if use_cfg:
                    pred_neg = self.transformer(
                        img=img,
                        img_ids=image_pos_ids,
                        img_mask=img_mask,
                        txt=neg_embeds,
                        txt_ids=neg_text_ids,
                        txt_mask=neg_mask,
                        timesteps=t_vec,
                    )
                    pred = pred_neg + guidance_scale * (pred - pred_neg)

                # Euler step from t_curr to t_prev.
                img = img + (t_prev - t_curr) * pred

                progress_bar.update()

        if output_type == "latent":
            image = img

        else:
            # --- Unpatchify: [B, num_patches, C*P*P] -> [B, 3, H, W] ---
            pixel_shape = (batch_size, 3, height, width)
            image = vae_unflatten(img.float(), pixel_shape, patch_size=patch_size)
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ZImagePipelineOutput(images=image)
+class ZImageDCTParams: + patch_size: int = 1 + f_patch_size: int = 1 + in_channels: int = 128 + dim: int = 3840 + n_layers: int = 30 + n_refiner_layers: int = 2 + n_heads: int = 30 + n_kv_heads: int = 30 + norm_eps: float = 1e-5 + qk_norm: bool = True + cap_feat_dim: int = 2560 + rope_theta: int = 256 + t_scale: float = 1000.0 + axes_dims: list = field(default_factory=lambda: [32, 48, 48]) + axes_lens: list = field(default_factory=lambda: [1536, 512, 512]) + adaln_embed_dim: int = 256 + use_x0: bool = True + # DCT decoder params + decoder_hidden_size: int = 3840 + decoder_num_res_blocks: int = 4 + decoder_max_freqs: int = 8 + + +class FakeConfig: + # for diffusers compatability + def __init__(self): + self.patch_size = 1 + + +def _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype): + if attn_mask is None: + return None + if attn_mask.ndim == 2: + attn_mask = attn_mask[:, None, None, :] + if attn_mask.dtype == torch.bool: + new_mask = torch.zeros_like(attn_mask, dtype=dtype) + new_mask.masked_fill_(~attn_mask, float("-inf")) + return new_mask + return attn_mask + + +def _native_attention_wrapper( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, +) -> torch.Tensor: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + attn_mask = _process_mask(attn_mask, query.dtype) + out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + return out.transpose(1, 2).contiguous() + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, 
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned gain.

    The input is divided by `sqrt(mean(x**2) + eps)` computed along the last
    dimension, then multiplied element-wise by `weight` (initialized to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = x * inv_rms
        return normalized * self.weight
class ZImageTransformerBlock(nn.Module):
    """Self-attention plus feed-forward block with RMS norms before and after each branch.

    With `modulation=True` a linear layer maps the conditioning vector to
    four chunks: a scale and a tanh-squashed gate for the attention branch,
    and the same pair for the feed-forward branch. With `modulation=False`
    both branches are plain residual updates.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation: bool = True,
        adaln_embed_dim: int = 256,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.layer_id = layer_id
        self.modulation = modulation

        self.attention = ZImageAttention(dim, n_heads, n_kv_heads, qk_norm, norm_eps)
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))

        # Registration order matters for state_dict layout; keep it fixed.
        for norm_name in (
            "attention_norm1",
            "ffn_norm1",
            "attention_norm2",
            "ffn_norm2",
        ):
            setattr(self, norm_name, RMSNorm(dim, eps=norm_eps))

        if modulation:
            cond_dim = min(dim, adaln_embed_dim)
            self.adaLN_modulation = nn.ModuleList(
                [nn.Linear(cond_dim, 4 * dim, bias=True)]
            )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """Apply the block to `x`.

        `adaln_input` must be provided when the block was built with
        `modulation=True`; it is ignored otherwise.
        """
        if not self.modulation:
            attended = self.attention(
                self.attention_norm1(x),
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + self.attention_norm2(attended)
            return x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

        assert adaln_input is not None
        # Insert a sequence axis so the modulation broadcasts over tokens,
        # then split the last axis into the four modulation terms.
        modulation = self.adaLN_modulation[0](adaln_input).unsqueeze(1)
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=2)
        gate_msa = gate_msa.tanh()
        gate_mlp = gate_mlp.tanh()

        attended = self.attention(
            self.attention_norm1(x) * (1.0 + scale_msa),
            attention_mask=attn_mask,
            freqs_cis=freqs_cis,
        )
        x = x + gate_msa * self.attention_norm2(attended)

        ffn_out = self.feed_forward(self.ffn_norm1(x) * (1.0 + scale_mlp))
        return x + gate_mlp * self.ffn_norm2(ffn_out)
class NerfEmbedder(nn.Module):
    """Embed per-pixel values together with a fixed 2-D cosine positional basis.

    Each position inside a patch receives `max_freqs**2` cosine features,
    which are concatenated to the input channels and passed through a single
    linear layer.
    """

    # Upper bound on the number of cached positional bases per instance.
    _POS_CACHE_SIZE = 4

    def __init__(self, in_channels, hidden_size_input, max_freqs):
        super().__init__()
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input
        self.embedder = nn.Sequential(
            nn.Linear(in_channels + max_freqs**2, hidden_size_input)
        )
        # Per-instance cache keyed by (patch_size, device, dtype). A plain
        # dict (not a buffer), so it never appears in the state_dict and is
        # released together with the module.
        self._pos_cache = {}

    def fetch_pos(self, patch_size, device, dtype):
        """Return the cosine basis for a `patch_size` x `patch_size` grid.

        The result has shape `(1, patch_size**2, max_freqs**2)` and is cached
        on this instance, so repeated calls with the same arguments return
        the same tensor object.
        """
        key = (patch_size, device, dtype)
        cached = self._pos_cache.get(key)
        if cached is not None:
            return cached

        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")

        pos_x = pos_x.reshape(-1, 1, 1)
        pos_y = pos_y.reshape(-1, 1, 1)

        freqs = torch.linspace(
            0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device
        )
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]

        # Higher frequency pairs are down-weighted by 1 / (1 + fx * fy).
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct_x = torch.cos(pos_x * freqs_x * torch.pi)
        dct_y = torch.cos(pos_y * freqs_y * torch.pi)
        dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)

        # Keep the cache bounded: drop the oldest entry when full.
        if len(self._pos_cache) >= self._POS_CACHE_SIZE:
            self._pos_cache.pop(next(iter(self._pos_cache)))
        self._pos_cache[key] = dct
        return dct

    def forward(self, inputs):
        """Embed `inputs` of shape `(B, P2, C)`, where `P2` is a perfect square.

        Returns a tensor of shape `(B, P2, hidden_size_input)` in the dtype
        of `inputs`.

        Raises:
            ValueError: If `P2` is not a perfect square.
        """
        B, P2, C = inputs.shape
        patch_size = math.isqrt(P2)
        if patch_size * patch_size != P2:
            raise ValueError(
                f"NerfEmbedder expects a square number of positions, got {P2}"
            )

        original_dtype = inputs.dtype
        # The embedding is computed in float32 regardless of outer autocast.
        with torch.autocast("cuda", enabled=False):
            inputs = inputs.float()
            dct = self.fetch_pos(patch_size, inputs.device, torch.float32)
            # expand is a view; torch.cat materializes the batch copy once.
            inputs = torch.cat([inputs, dct.expand(B, -1, -1)], dim=-1)
            # NOTE: Module.float() casts the embedder's parameters in place
            # and returns the module, so they remain float32 afterwards.
            inputs = self.embedder.float()(inputs)
        return inputs.to(original_dtype)
[ResBlock(model_channels) for _ in range(num_res_blocks)] + ) + self.final_layer = DCTFinalLayer(model_channels, out_channels) + nn.init.xavier_uniform_(self.cond_embed.weight) + nn.init.constant_(self.cond_embed.bias, 0) + + def forward(self, x, c): + x = self.input_embedder(x) + c = self.cond_embed(c) + y = c.reshape(c.shape[0], self.patch_size**2, -1) + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x) + + +class ZImageDCT(nn.Module): + def __init__(self, params: ZImageDCTParams): + super().__init__() + self.config = FakeConfig() + self.in_channels = params.in_channels + self.out_channels = params.in_channels + self.patch_size = params.patch_size + self.f_patch_size = params.f_patch_size + self.dim = params.dim + self.n_heads = params.n_heads + self.rope_theta = params.rope_theta + self.t_scale = params.t_scale + self.adaln_embed_dim = params.adaln_embed_dim + + self.x_embedder = nn.Linear( + self.f_patch_size * self.patch_size * self.patch_size * params.in_channels, + params.dim, + bias=True, + ) + + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + i, + params.dim, + params.n_heads, + params.n_kv_heads, + params.norm_eps, + params.qk_norm, + modulation=True, + adaln_embed_dim=params.adaln_embed_dim, + ) + for i in range(params.n_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + i, + params.dim, + params.n_heads, + params.n_kv_heads, + params.norm_eps, + params.qk_norm, + modulation=False, + ) + for i in range(params.n_refiner_layers) + ] + ) + + self.t_embedder = TimestepEmbedder( + min(params.dim, params.adaln_embed_dim), mid_size=1024 + ) + + self.cap_embedder = nn.Sequential( + RMSNorm(params.cap_feat_dim, eps=params.norm_eps), + nn.Linear(params.cap_feat_dim, params.dim, bias=True), + ) + + self.x_pad_token = nn.Parameter(torch.empty((1, params.dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, params.dim))) + + self.layers = nn.ModuleList( + [ + 
ZImageTransformerBlock( + i, + params.dim, + params.n_heads, + params.n_kv_heads, + params.norm_eps, + params.qk_norm, + modulation=True, + adaln_embed_dim=params.adaln_embed_dim, + ) + for i in range(params.n_layers) + ] + ) + + head_dim = params.dim // params.n_heads + assert head_dim == sum(params.axes_dims) + self.axes_dims = params.axes_dims + self.axes_lens = params.axes_lens + + self.rope_embedder = RopeEmbedder( + theta=params.rope_theta, + axes_dims=params.axes_dims, + axes_lens=params.axes_lens, + ) + + self.dec_net = SimpleMLPAdaLN( + in_channels=params.in_channels, + model_channels=params.decoder_hidden_size, + out_channels=params.in_channels, + z_channels=params.dim, + num_res_blocks=params.decoder_num_res_blocks, + patch_size=self.patch_size, + max_freqs=params.decoder_max_freqs, + ) + + if params.use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def _forward( + self, + img: Tensor, + img_ids: Tensor, + img_mask: Tensor, + txt: Tensor, + txt_ids: Tensor, + txt_mask: Tensor, + timesteps: Tensor, + ): + B = img.shape[0] + num_patches = img.shape[1] + + pixel_values = img.reshape( + B * num_patches, self.patch_size**2, self.in_channels + ) + + timesteps = (1 - timesteps) * self.t_scale + timesteps_embedding = self.t_embedder(timesteps) + + img_hidden = self.x_embedder(img) + txt_hidden = self.cap_embedder(txt) + + img_pe = self.rope_embedder(img_ids) + txt_pe = self.rope_embedder(txt_ids) + + for layer in self.noise_refiner: + img_hidden = layer(img_hidden, img_mask, img_pe, timesteps_embedding) + + for layer in self.context_refiner: + txt_hidden = layer(txt_hidden, txt_mask, txt_pe) + + mixed_hidden = torch.cat((txt_hidden, img_hidden), 1) + mixed_mask = torch.cat((txt_mask, 
def prepare_latent_image_ids(start_indices, height, width, patch_size=2, max_offset=0):
    """Generate 3D positional IDs for image patches.

    Args:
        start_indices: One value per batch item (1-D tensor, list or tuple);
            written to channel 0 of every position of that item.
        height, width: Grid size before division by `patch_size`.
        patch_size: Divisor applied to `height` and `width` (floor division).
        max_offset: Unused; kept for call compatibility.

    Returns:
        int32 tensor of shape `(batch, rows * cols, 3)` where channel 0 is the
        item's start index, channel 1 the row index and channel 2 the column
        index, with `rows = height // patch_size` and
        `cols = width // patch_size`. The tensor is created on the default
        device, independent of where `start_indices` lives.
    """
    if isinstance(start_indices, (list, tuple)):
        start_indices = torch.tensor(start_indices)

    batch_size = len(start_indices)
    rows = height // patch_size
    cols = width // patch_size

    ids = torch.zeros(batch_size, rows, cols, 3, dtype=torch.int32)
    # One transfer/cast for the whole batch instead of a per-item copy.
    starts = start_indices.to(device=ids.device, dtype=torch.int32)
    ids[..., 0] = starts.reshape(batch_size, 1, 1)
    ids[..., 1] = torch.arange(rows, dtype=torch.int32)[:, None]
    ids[..., 2] = torch.arange(cols, dtype=torch.int32)[None, :]

    return ids.reshape(batch_size, rows * cols, 3)
a/ai-toolkit/extensions_built_in/flex2/__init__.py b/ai-toolkit/extensions_built_in/flex2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..75aea7511e9d35e680b4c3eec94ee7a7800eb55e --- /dev/null +++ b/ai-toolkit/extensions_built_in/flex2/__init__.py @@ -0,0 +1,6 @@ +from .flex2 import Flex2 + +AI_TOOLKIT_MODELS = [ + # put a list of models here + Flex2 +] diff --git a/ai-toolkit/extensions_built_in/flex2/flex2.py b/ai-toolkit/extensions_built_in/flex2/flex2.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d73be6d88bc3c7467ea41f841a85e2aedd7ab1 --- /dev/null +++ b/ai-toolkit/extensions_built_in/flex2/flex2.py @@ -0,0 +1,527 @@ +import os +from typing import TYPE_CHECKING, List + +import torch +import torchvision +import yaml +from toolkit import train_tools +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from diffusers import FluxTransformer2DModel, AutoencoderKL +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import get_accelerator, unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.mask import generate_random_mask, random_dialate_mask +from toolkit.util.quantize import quantize, get_qtype +from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer +from .pipeline import Flex2Pipeline +from einops import rearrange, repeat +import random +import torch.nn.functional as F + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + 
def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5):
    """Gaussian-blur `img` with probability `p`, otherwise return it unchanged.

    The kernel size is drawn uniformly from
    `[min_kernel_size, max_kernel_size]`; an even draw is bumped up by one so
    the kernel passed to `gaussian_blur` is always odd.
    """
    if random.random() >= p:
        return img

    kernel_size = random.randint(min_kernel_size, max_kernel_size)
    # adds 1 to even sizes and 0 to odd ones
    kernel_size += 1 - kernel_size % 2
    return torchvision.transforms.functional.gaussian_blur(
        img, kernel_size=kernel_size
    )
in the model directory + # it is here because for finetuning we only save the transformer usually + # so we need this for the VAE, te, etc + base_model_path = self.model_config.name_or_path_original + + transformer_path = model_path + transformer_subfolder = 'transformer' + if os.path.exists(transformer_path): + transformer_subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + self.print_and_status_update("Loading transformer") + transformer = FluxTransformer2DModel.from_pretrained( + transformer_path, + subfolder=transformer_subfolder, + torch_dtype=dtype, + ) + transformer.to(self.quantize_device, dtype=dtype) + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained( + base_model_path, subfolder="tokenizer_2", torch_dtype=dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + base_model_path, subfolder="text_encoder_2", torch_dtype=dtype + ) + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder_2, weights=get_qtype( + self.model_config.qtype)) + freeze(text_encoder_2) + flush() + + self.print_and_status_update("Loading CLIP") + text_encoder = CLIPTextModel.from_pretrained( + base_model_path, 
subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype) + + self.noise_scheduler = Flex2.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + pipe: Flex2Pipeline = Flex2Pipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = Flex2.get_train_scheduler() + + pipeline: Flex2Pipeline = Flex2Pipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer_2=self.tokenizer[1], + vae=unwrap_model(self.vae), + 
transformer=unwrap_model(self.transformer) + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: Flex2Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if gen_config.ctrl_img is None: + control_img = None + else: + control_img = Image.open(gen_config.ctrl_img) + if ".inpaint." not in gen_config.ctrl_img: + control_img = control_img.convert("RGB") + else: + # make sure it has an alpha + if control_img.mode != "RGBA": + raise ValueError("Inpainting images must have an alpha channel") + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + control_image=control_img, + control_image_idx=gen_config.ctrl_idx, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + guidance_embedding_scale: float, + bypass_guidance_embedding: bool, + **kwargs + ): + with torch.no_grad(): + bs, c, h, w = latent_model_input.shape + latent_model_input_packed = rearrange( + latent_model_input, + "b c (h ph) (w pw) -> b (h w) (c ph pw)", + ph=2, + pw=2 + ) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", + b=bs).to(self.device_torch) + + txt_ids = torch.zeros( + bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + # # handle guidance + if self.unet_unwrapped.config.guidance_embeds: + if 
isinstance(guidance_embedding_scale, list): + guidance = torch.tensor( + guidance_embedding_scale, device=self.device_torch) + else: + guidance = torch.tensor( + [guidance_embedding_scale], device=self.device_torch) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + + if bypass_guidance_embedding: + bypass_flux_guidance(self.unet) + + cast_dtype = self.unet.dtype + # changes from orig implementation + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + noise_pred = self.unet( + hidden_states=latent_model_input_packed.to( + self.device_torch, cast_dtype), + timestep=timestep / 1000, + encoder_hidden_states=text_embeddings.text_embeds.to( + self.device_torch, cast_dtype), + pooled_projections=text_embeddings.pooled_embeds.to( + self.device_torch, cast_dtype), + txt_ids=txt_ids, + img_ids=img_ids, + guidance=guidance, + return_dict=False, + **kwargs, + )[0] + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + noise_pred = rearrange( + noise_pred, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=latent_model_input.shape[2] // 2, + w=latent_model_input.shape[3] // 2, + ph=2, + pw=2, + c=self.vae.config.latent_channels + ) + + if bypass_guidance_embedding: + restore_flux_guidance(self.unet) + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux( + self.tokenizer, + self.text_encoder, + prompt, + max_length=512, + ) + pe = PromptEmbeds( + prompt_embeds + ) + pe.pooled_embeds = pooled_prompt_embeds + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return self.model.proj_out.weight.requires_grad + + def get_te_has_grad(self): + # return from a weight if it has grad + return 
self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: FluxTransformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, 'transformer'), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + return (noise - batch.latents).detach() + + def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'): + with torch.no_grad(): + # inpainting input is 0-1 (bs, 4, h, w) on batch.inpaint_tensor + # 4th channel is the mask with 1 being keep area and 0 being area to inpaint. + # todo handle dropout on a batch item level, this frops out the entire batch + do_dropout = random.random() < self.inpaint_dropout if self.inpaint_dropout > 0.0 else False + # do random mask if we dont have one + inpaint_tensor = batch.inpaint_tensor + if inpaint_tensor is None and batch.mask_tensor is not None: + # we have a mask tensor, use it + inpaint_tensor = batch.mask_tensor + + if self.inpaint_random_chance > 0.0: + do_random = random.random() < self.inpaint_random_chance + if do_random: + # force a random tensor + inpaint_tensor = None + + if inpaint_tensor is None and not do_dropout and self.do_random_inpainting: + # generate a random one since we dont have one + # this will make random blobs, invert the blobs for now as we normanlly inpaint the alpha + inpaint_tensor = 1 - generate_random_mask( + batch_size=latents.shape[0], + height=latents.shape[2], + width=latents.shape[3], + device=latents.device, + ).to(latents.device, latents.dtype) + if inpaint_tensor is not None and not do_dropout: + + if inpaint_tensor.shape[1] == 4: + # get just the mask + inpainting_tensor_mask = inpaint_tensor[:, 
3:4, :, :].to(latents.device, dtype=latents.dtype) + elif inpaint_tensor.shape[1] == 3: + # rgb mask. Just get one channel + inpainting_tensor_mask = inpaint_tensor[:, 0:1, :, :].to(latents.device, dtype=latents.dtype) + # mask is 0-1 with 1 being inpaint area, we need to invert it for now, it is re inverted later + inpaint_tensor = 1 - inpaint_tensor + else: + inpainting_tensor_mask = inpaint_tensor + + # # use our batch latents so we cna avoid encoding again + inpainting_latent = batch.latents + + # resize the mask to match the new encoded size + inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear') + inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype) + + if self.random_blur_mask: + # blur the mask + # Give it a channel dim of 1 + if len(inpainting_tensor_mask.shape) == 3: + # if it is 3d, add a channel dim + inpainting_tensor_mask = inpainting_tensor_mask.unsqueeze(1) + # we are at latent size, so keep kernel smaller + inpainting_tensor_mask = random_blur( + inpainting_tensor_mask, + min_kernel_size=3, + max_kernel_size=8, + p=0.5 + ) + + do_mask_invert = False + if self.invert_inpaint_mask_chance > 0.0: + do_mask_invert = random.random() < self.invert_inpaint_mask_chance + if do_mask_invert: + # invert the mask + inpainting_tensor_mask = 1 - inpainting_tensor_mask + + # mask out the inpainting area, it is currently 0 for inpaint area, and 1 for keep area + # we are zeroing our the latents in the inpaint area not on the pixel space. + inpainting_latent = inpainting_latent * inpainting_tensor_mask + + # do the random dialation after the mask is applied so it does not match perfectly. + # this will make the model learn to prevent weird edges + if self.random_dialate_mask: + inpainting_tensor_mask = random_dialate_mask( + inpainting_tensor_mask, + max_percent=0.05 + ) + + # mask needs to be 1 for inpaint area and 0 for area to leave alone. 
So flip it. + inpainting_tensor_mask = 1 - inpainting_tensor_mask + # leave the mask as 0-1 and concat on channel of latents + inpainting_latent = torch.cat((inpainting_latent, inpainting_tensor_mask), dim=1) + else: + # we have iinpainting but didnt get a control. or we are doing a dropout + # the input needs to be all zeros for the latents and all 1s for the mask + inpainting_latent = torch.zeros_like(latents) + # add ones for the mask since we are technically inpainting everything + inpainting_latent = torch.cat((inpainting_latent, torch.ones_like(inpainting_latent[:, :1, :, :])), dim=1) + + control_tensor = batch.control_tensor + if control_tensor is None: + # concat random normal noise onto the latents + # check dimension, this is before they are rearranged + # it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging + ctrl = torch.zeros( + latents.shape[0], # bs + latents.shape[1], + latents.shape[2], + latents.shape[3], + device=latents.device, + dtype=latents.dtype + ) + # inpainting always comes first + ctrl = torch.cat((inpainting_latent, ctrl), dim=1) + latents = torch.cat((latents, ctrl), dim=1) + return latents.detach() + # if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w] + # if we have 1, it comes in like [bs, ch, h, w] + # stack out control tensors to be [bs, ch * num_control_images, h, w] + + control_tensor_list = [] + if len(control_tensor.shape) == 4: + control_tensor_list.append(control_tensor) + else: + num_control_images = control_tensor.shape[1] + # reshape + control_tensor = control_tensor.view( + control_tensor.shape[0], + control_tensor.shape[1] * control_tensor.shape[2], + control_tensor.shape[3], + control_tensor.shape[4] + ) + control_tensor_list = control_tensor.chunk(num_control_images, dim=1) + + do_dropout = random.random() < self.control_dropout if self.control_dropout > 0.0 else False + if do_dropout: + # dropout with zeros + control_latent = 
torch.zeros_like(batch.latents) + else: + # we only have one control so we randomly pick from this list + control_tensor = random.choice(control_tensor_list) + # it is 0-1 need to convert to -1 to 1 + control_tensor = control_tensor * 2 - 1 + + control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype) + + # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it + if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]: + control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear') + + # encode it + control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype) + + # inpainting always comes first + control_latent = torch.cat((inpainting_latent, control_latent), dim=1) + # concat it onto the latents + latents = torch.cat((latents, control_latent), dim=1) + return latents.detach() \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/flex2/pipeline.py b/ai-toolkit/extensions_built_in/flex2/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..8661ff13e3b3a2f9afade97fbe0026880caac73f --- /dev/null +++ b/ai-toolkit/extensions_built_in/flex2/pipeline.py @@ -0,0 +1,348 @@ +from diffusers import FluxControlPipeline, FluxTransformer2DModel +from typing import Any, Callable, Dict, List, Optional, Union +import torch + +from diffusers.image_processor import PipelineImageInput +import numpy as np +from PIL import Image +import torch.nn.functional as F +from torchvision import transforms +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, XLA_AVAILABLE + + +class Flex2Pipeline(FluxControlPipeline): + def __init__( + self, + scheduler, + vae, + text_encoder, + tokenizer, + text_encoder_2, + tokenizer_2, + transformer, + ): + super().__init__(scheduler, vae, 
text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + control_image: Optional[PipelineImageInput] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + control_image_idx: int = 0, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. 
If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. 
+ latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. 
The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. 
Prepare latent variables + # num_channels_latents = self.transformer.config.in_channels // 8 + num_channels_latents = 128 // 8 + + # pull mask off control image if there is one it is a pil image + mask = None + if control_image is not None and control_image.mode == "RGBA": + control_img_array = np.array(control_image) + mask = control_img_array[:, :, 3:4] + # scale it to 0 - 1 + mask = mask / 255.0 + # control image ideally would be a full image here + control_img_array = control_img_array[:, :, :3] + control_image = Image.fromarray(control_img_array.astype(np.uint8)) + + if control_image is not None: + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + if control_image.ndim == 4: + num_control_channels = num_channels_latents + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + if mask is not None: + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + mask = transform(mask).to(device, dtype=control_image.dtype).unsqueeze(0) + # resize mask to match control image + mask = F.interpolate(mask, size=(control_image.shape[2], control_image.shape[3]), mode="bilinear", align_corners=False) + mask = mask.to(device) + # apply the mask to the control image so the inpaint latent area is 0 + # mask is currently 0 for inpaint area and 1 for image area + control_image = control_image * mask + # invert mask so it is 1 for inpaint area and 0 for image area + mask = 1 - mask + control_image = torch.cat([control_image, mask], dim=1) + num_control_channels += 1 + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_control_channels, + 
height_control_image, + width_control_image, + ) + + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # make a blank control latent + control_image_list = [ + # impainting + torch.cat([torch.zeros_like(latents), torch.ones_like(latents[:, :, :4])], dim=2), + # control + torch.zeros_like(latents), + ] + if control_image is not None: + + control_image_list[control_image_idx] = control_image + + latent_model_input = torch.cat([latents] + control_image_list, dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) + + \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/ai-toolkit/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..11fa8a9eef13c5546bc2de5723301b4d4838a0cc --- /dev/null +++ b/ai-toolkit/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py @@ -0,0 +1,235 @@ +import copy +import random +from collections import OrderedDict +import os +from contextlib import nullcontext +from typing import Optional, Union, List +from torch.utils.data import ConcatDataset, DataLoader + +from toolkit.config_modules import ReferenceDatasetConfig +from toolkit.data_loader import PairedImageDataset +from 
toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +from toolkit import train_tools +import torch +from jobs.process import BaseSDTrainProcess +import random +from toolkit.basic import value_map + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class ReferenceSliderConfig: + def __init__(self, **kwargs): + self.additional_losses: List[str] = kwargs.get('additional_losses', []) + self.weight_jitter: float = kwargs.get('weight_jitter', 0.0) + self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])] + + +class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): + sd: StableDiffusion + data_loader: DataLoader = None + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + self.prompt_txt_list = None + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.slider_config = ReferenceSliderConfig(**self.get_conf('slider', {})) + + def load_datasets(self): + if self.data_loader is None: + print(f"Loading datasets") + datasets = [] + for dataset in self.slider_config.datasets: + print(f" - Dataset: {dataset.pair_folder}") + config = { + 'path': dataset.pair_folder, + 'size': dataset.size, + 'default_prompt': dataset.target_class, + 'network_weight': dataset.network_weight, + 'pos_weight': dataset.pos_weight, + 'neg_weight': dataset.neg_weight, + 'pos_folder': dataset.pos_folder, + 'neg_folder': dataset.neg_folder, + } + image_dataset = PairedImageDataset(config) + datasets.append(image_dataset) + + concatenated_dataset = ConcatDataset(datasets) + self.data_loader = DataLoader( + concatenated_dataset, + batch_size=self.train_config.batch_size, + shuffle=True, + 
num_workers=2 + ) + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + self.load_datasets() + + pass + + def hook_train_loop(self, batch): + with torch.no_grad(): + imgs, prompts, network_weights = batch + network_pos_weight, network_neg_weight = network_weights + + if isinstance(network_pos_weight, torch.Tensor): + network_pos_weight = network_pos_weight.item() + if isinstance(network_neg_weight, torch.Tensor): + network_neg_weight = network_neg_weight.item() + + # get an array of random floats between -weight_jitter and weight_jitter + loss_jitter_multiplier = 1.0 + weight_jitter = self.slider_config.weight_jitter + if weight_jitter > 0.0: + jitter_list = random.uniform(-weight_jitter, weight_jitter) + orig_network_pos_weight = network_pos_weight + network_pos_weight += jitter_list + network_neg_weight += (jitter_list * -1.0) + # penalize the loss for its distance from network_pos_weight + # a jitter_list of abs(3.0) on a weight of 5.0 is a 60% jitter + # so the loss_jitter_multiplier needs to be 0.4 + loss_jitter_multiplier = value_map(abs(jitter_list), 0.0, weight_jitter, 1.0, 0.0) + + + # if items in network_weight list are tensors, convert them to floats + + dtype = get_torch_dtype(self.train_config.dtype) + imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype) + # split batched images in half so left is negative and right is positive + negative_images, positive_images = torch.chunk(imgs, 2, dim=3) + + positive_latents = self.sd.encode_images(positive_images) + negative_latents = self.sd.encode_images(negative_images) + + height = positive_images.shape[2] + width = positive_images.shape[3] + batch_size = positive_images.shape[0] + + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() + + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + 
self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch) + timesteps = timesteps.long() + + # get noise + noise_positive = self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + noise_negative = noise_positive.clone() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps) + noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps) + + noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0) + noise = torch.cat([noise_positive, noise_negative], dim=0) + timesteps = torch.cat([timesteps, timesteps], dim=0) + network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0] + + self.optimizer.zero_grad() + noisy_latents.requires_grad = False + + # if training text encoder enable grads, else do context of no grad + with torch.set_grad_enabled(self.train_config.train_text_encoder): + # fix issue with them being tuples sometimes + prompt_list = [] + for prompt in prompts: + if isinstance(prompt, tuple): + prompt = prompt[0] + prompt_list.append(prompt) + conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype) + conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + + # if self.model_config.is_xl: + # # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop + # network_multiplier_list = network_multiplier + # noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0) + # noise_list = torch.chunk(noise, 2, dim=0) + # timesteps_list = torch.chunk(timesteps, 
2, dim=0) + # conditional_embeds_list = split_prompt_embeds(conditional_embeds) + # else: + network_multiplier_list = [network_multiplier] + noisy_latent_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditional_embeds_list = [conditional_embeds] + + losses = [] + # allow to chunk it out to save vram + for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip( + network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list + ): + with self.network: + assert self.network.is_active + + self.network.multiplier = network_multiplier + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + ) + noise = noise.to(self.device_torch, dtype=dtype) + + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() * loss_jitter_multiplier + + loss_float = loss.item() + losses.append(loss_float) + + # back propagate loss to free ram + loss.backward() + + # apply gradients + optimizer.step() + lr_scheduler.step() + + # reset network + self.network.multiplier = 1.0 + + loss_dict = OrderedDict( + {'loss': sum(losses) / len(losses) if len(losses) > 0 else 0.0} + ) + + return loss_dict + # end hook_train_loop diff --git a/ai-toolkit/extensions_built_in/image_reference_slider_trainer/__init__.py b/ai-toolkit/extensions_built_in/image_reference_slider_trainer/__init__.py new file mode 
100644 index 0000000000000000000000000000000000000000..8a15f646bde32a68d194838c4c293619caa8bf93 --- /dev/null +++ b/ai-toolkit/extensions_built_in/image_reference_slider_trainer/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# We make a subclass of Extension +class ImageReferenceSliderTrainer(Extension): + # uid must be unique, it is how the extension is identified + uid = "image_reference_slider_trainer" + + # name is the name of the extension for printing + name = "Image Reference Slider Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ImageReferenceSliderTrainerProcess import ImageReferenceSliderTrainerProcess + return ImageReferenceSliderTrainerProcess + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ImageReferenceSliderTrainer +] diff --git a/ai-toolkit/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml b/ai-toolkit/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8b0f4734ae09fb7e942e33089014ffe59cfd7720 --- /dev/null +++ b/ai-toolkit/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml @@ -0,0 +1,107 @@ +--- +job: extension +config: + name: example_name + process: + - type: 'image_reference_slider_trainer' + training_folder: "/mnt/Train/out/LoRA" + device: cuda:0 + # for tensorboard logging + log_dir: "/home/jaret/Dev/.tensorboard" + network: + type: "lora" + linear: 8 + linear_alpha: 8 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 5000 + lr: 1e-4 + train_unet: true + gradient_checkpointing: true + train_text_encoder: true + 
optimizer: "adamw" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 1 + dtype: bf16 + xformers: true + skip_first_sample: true + noise_offset: 0.0 + model: + name_or_path: "/path/to/model.safetensors" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + save: + dtype: float16 # precision to save + save_every: 1000 # save every this many steps + max_step_saves_to_keep: 2 # only affects step counts + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of a woman with red hair taking a selfie --m -3" + - "photo of a woman with red hair taking a selfie --m -1" + - "photo of a woman with red hair taking a selfie --m 1" + - "photo of a woman with red hair taking a selfie --m 3" + - "close up photo of a man smiling at the camera, in a tank top --m -3" + - "close up photo of a man smiling at the camera, in a tank top--m -1" + - "close up photo of a man smiling at the camera, in a tank top --m 1" + - "close up photo of a man smiling at the camera, in a tank top --m 3" + - "photo of a blonde woman smiling, barista --m -3" + - "photo of a blonde woman smiling, barista --m -1" + - "photo of a blonde woman smiling, barista --m 1" + - "photo of a blonde woman smiling, barista --m 3" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m 1" + - "photo of a Christina Hendricks --m 3" + - "photo of a Christina Ricci --m -3" + - "photo of a Christina Ricci --m -1" + - "photo of a Christina Ricci --m 1" + - "photo of a Christina Ricci --m 3" + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: 
false # not supported yet + verbose: false + + slider: + datasets: + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 2.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 4.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + + +# you can put any information you want here, and it will be saved in the model +# the below is an example. I recommend doing trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. \ No newline at end of file diff --git a/ai-toolkit/extensions_built_in/sd_trainer/DiffusionTrainer.py b/ai-toolkit/extensions_built_in/sd_trainer/DiffusionTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..51c081b3cff3cc2997dab65e22ed0785a9d01def --- /dev/null +++ b/ai-toolkit/extensions_built_in/sd_trainer/DiffusionTrainer.py @@ -0,0 +1,373 @@ +from collections import OrderedDict +import os +import sqlite3 +import asyncio +import concurrent.futures +from extensions_built_in.sd_trainer.SDTrainer import SDTrainer +from typing import Literal, Optional +import threading +import time +import signal +from toolkit.basic import flush +from toolkit.print import print_acc + +AITK_Status = Literal["running", "stopped", "error", "completed"] + + +class DiffusionTrainer(SDTrainer): + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super(DiffusionTrainer, self).__init__(process_id, job, config, **kwargs) + self.sqlite_db_path = self.config.get("sqlite_db_path", 
"./aitk_db.db") + self.job_id = os.environ.get("AITK_JOB_ID", None) + self.job_id = self.job_id.strip() if self.job_id is not None else None + self.is_ui_trainer = True + if not os.path.exists(self.sqlite_db_path): + self.is_ui_trainer = False + else: + print(f"Using SQLite database at {self.sqlite_db_path}") + if self.job_id is None: + self.is_ui_trainer = False + else: + print(f"Job ID: \"{self.job_id}\"") + + if self.is_ui_trainer: + self.is_stopping = False + # Create a thread pool for database operations + self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # Track all async tasks + self._async_tasks = [] + # Initialize the status + self._run_async_operation(self._update_status("running", "Starting")) + self._stop_watcher_started = False + # self.start_stop_watcher(interval_sec=2.0) + + def start_stop_watcher(self, interval_sec: float = 5.0): + """ + Start a daemon thread that periodically checks should_stop() + and terminates the process immediately when triggered. + """ + if not self.is_ui_trainer: + return + if getattr(self, "_stop_watcher_started", False): + return + self._stop_watcher_started = True + t = threading.Thread( + target=self._stop_watcher_thread, args=(interval_sec,), daemon=True + ) + t.start() + + def _stop_watcher_thread(self, interval_sec: float): + while True: + try: + if self.should_stop(): + # Mark and update status (non-blocking; uses existing infra) + self.is_stopping = True + self._run_async_operation( + self._update_status("stopped", "Job stopped (remote)") + ) + # Best-effort flush pending async ops + try: + asyncio.run(self.wait_for_all_async()) + except RuntimeError: + pass + # Try to stop DB thread pool quickly + try: + self.thread_pool.shutdown(wait=False, cancel_futures=True) + except TypeError: + self.thread_pool.shutdown(wait=False) + print("") + print("****************************************************") + print(" Stop signal received; terminating process. 
") + print("****************************************************") + os.kill(os.getpid(), signal.SIGINT) + time.sleep(interval_sec) + except Exception: + time.sleep(interval_sec) + + def _run_async_operation(self, coro): + """Helper method to run an async coroutine and track the task.""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No event loop exists, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create a task and track it + if loop.is_running(): + task = asyncio.run_coroutine_threadsafe(coro, loop) + self._async_tasks.append(asyncio.wrap_future(task)) + else: + task = loop.create_task(coro) + self._async_tasks.append(task) + loop.run_until_complete(task) + + async def _execute_db_operation(self, operation_func): + """Execute a database operation in a separate thread with retry on lock.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.thread_pool, lambda: self._retry_db_operation(operation_func) + ) + + def _db_connect(self): + """Create a new connection for each operation to avoid locking.""" + conn = sqlite3.connect(self.sqlite_db_path, timeout=30.0) + conn.isolation_level = None # Enable autocommit mode + return conn + + def _retry_db_operation(self, operation_func, max_retries=3, base_delay=2.0): + """Retry a database operation with exponential backoff on lock errors.""" + last_error = None + for attempt in range(max_retries + 1): + try: + return operation_func() + except sqlite3.OperationalError as e: + if "database is locked" in str(e): + last_error = e + if attempt < max_retries: + delay = base_delay * (2 ** attempt) # 2s, 4s, 8s + print(f"[AITK] Database locked (attempt {attempt + 1}/{max_retries + 1}), retrying in {delay:.1f}s...") + time.sleep(delay) + else: + print(f"[AITK] Database locked after {max_retries + 1} attempts, giving up.") + else: + raise + raise last_error + + def should_stop(self): + if not self.is_ui_trainer: + return False + def 
_check_stop(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT stop FROM Job WHERE id = ?", (self.job_id,)) + stop = cursor.fetchone() + return False if stop is None else stop[0] == 1 + + return self._retry_db_operation(_check_stop) + + def should_return_to_queue(self): + if not self.is_ui_trainer: + return False + def _check_return_to_queue(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT return_to_queue FROM Job WHERE id = ?", (self.job_id,)) + return_to_queue = cursor.fetchone() + return False if return_to_queue is None else return_to_queue[0] == 1 + + return self._retry_db_operation(_check_return_to_queue) + + def maybe_stop(self): + if not self.is_ui_trainer: + return + if self.should_stop(): + self._run_async_operation( + self._update_status("stopped", "Job stopped")) + self.is_stopping = True + raise Exception("Job stopped") + if self.should_return_to_queue(): + self._run_async_operation( + self._update_status("queued", "Job queued")) + self.is_stopping = True + raise Exception("Job returning to queue") + + def should_save(self): + if not self.is_ui_trainer: + return False + def _check_save(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT save_now FROM Job WHERE id = ?", (self.job_id,)) + save_now = cursor.fetchone() + return False if save_now is None else save_now[0] == 1 + + return self._retry_db_operation(_check_save) + + def maybe_save(self): + if not self.is_ui_trainer: + return + if self.should_save(): + self.update_db_key("save_now", 0) + if self.progress_bar is not None: + self.progress_bar.pause() + print_acc(f"\nSaving at step {self.step_num}") + # clear any grads + self.optimizer.zero_grad() + self.save(self.step_num) + self.ensure_params_requires_grad() + flush() + if self.progress_bar is not None: + self.progress_bar.unpause() + self.save(self.step_num) + + async def _update_key(self, key, value): + if not 
self.accelerator.is_main_process: + return + + def _do_update(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute("BEGIN IMMEDIATE") + try: + # Convert the value to string if it's not already + if isinstance(value, str): + value_to_insert = value + else: + value_to_insert = str(value) + + # Use parameterized query for both the column name and value + update_query = f"UPDATE Job SET {key} = ? WHERE id = ?" + cursor.execute( + update_query, (value_to_insert, self.job_id)) + finally: + cursor.execute("COMMIT") + + await self._execute_db_operation(_do_update) + + def update_step(self): + """Non-blocking update of the step count.""" + if self.accelerator.is_main_process and self.is_ui_trainer: + self._run_async_operation(self._update_key("step", self.step_num)) + + def update_db_key(self, key, value): + """Non-blocking update a key in the database.""" + if self.accelerator.is_main_process and self.is_ui_trainer: + self._run_async_operation(self._update_key(key, value)) + + async def _update_status(self, status: AITK_Status, info: Optional[str] = None): + if not self.accelerator.is_main_process or not self.is_ui_trainer: + return + + def _do_update(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute("BEGIN IMMEDIATE") + try: + if info is not None: + cursor.execute( + "UPDATE Job SET status = ?, info = ? WHERE id = ?", + (status, info, self.job_id) + ) + else: + cursor.execute( + "UPDATE Job SET status = ? 
WHERE id = ?", + (status, self.job_id) + ) + finally: + cursor.execute("COMMIT") + + await self._execute_db_operation(_do_update) + + def update_status(self, status: AITK_Status, info: Optional[str] = None): + """Non-blocking update of status.""" + if self.accelerator.is_main_process and self.is_ui_trainer: + self._run_async_operation(self._update_status(status, info)) + + async def wait_for_all_async(self): + """Wait for all tracked async operations to complete.""" + if not self._async_tasks: + return + + try: + await asyncio.gather(*self._async_tasks) + except Exception as e: + pass + finally: + # Clear the task list after completion + self._async_tasks.clear() + + def on_error(self, e: Exception): + super(DiffusionTrainer, self).on_error(e) + if self.is_ui_trainer: + try: + if self.accelerator.is_main_process and not self.is_stopping: + self.update_status("error", str(e)) + self.update_db_key("step", self.last_save_step) + asyncio.run(self.wait_for_all_async()) + except Exception as db_err: + print(f"[AITK] Warning: failed to update DB during error handling: {db_err}") + finally: + self.thread_pool.shutdown(wait=True) + + def handle_timing_print_hook(self, timing_dict): + if "train_loop" not in timing_dict: + print("train_loop not found in timing_dict", timing_dict) + return + seconds_per_iter = timing_dict["train_loop"] + # determine iter/sec or sec/iter + if seconds_per_iter < 1: + iters_per_sec = 1 / seconds_per_iter + self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec") + else: + self.update_db_key( + "speed_string", f"{seconds_per_iter:.2f} sec/iter") + + def done_hook(self): + super(DiffusionTrainer, self).done_hook() + if self.is_ui_trainer: + self.update_status("completed", "Training completed") + # Wait for all async operations to finish before shutting down + asyncio.run(self.wait_for_all_async()) + self.thread_pool.shutdown(wait=True) + + def end_step_hook(self): + super(DiffusionTrainer, self).end_step_hook() + if self.is_ui_trainer: + 
self.update_step() + self.maybe_stop() + self.maybe_save() + + def hook_before_model_load(self): + super().hook_before_model_load() + if self.is_ui_trainer: + self.maybe_stop() + self.update_status("running", "Loading model") + + def before_dataset_load(self): + super().before_dataset_load() + if self.is_ui_trainer: + self.maybe_stop() + self.update_status("running", "Loading dataset") + + def hook_before_train_loop(self): + super().hook_before_train_loop() + if self.is_ui_trainer: + self.maybe_stop() + self.update_step() + self.update_status("running", "Training") + self.timer.add_after_print_hook(self.handle_timing_print_hook) + + def status_update_hook_func(self, string): + self.update_status("running", string) + + def hook_after_sd_init_before_load(self): + super().hook_after_sd_init_before_load() + if self.is_ui_trainer: + self.maybe_stop() + self.sd.add_status_update_hook(self.status_update_hook_func) + + def sample_step_hook(self, img_num, total_imgs): + super().sample_step_hook(img_num, total_imgs) + if self.is_ui_trainer: + self.maybe_stop() + self.update_status( + "running", f"Generating images - {img_num + 1}/{total_imgs}") + + def sample(self, step=None, is_first=False): + self.maybe_stop() + total_imgs = len(self.sample_config.prompts) + self.update_status("running", f"Generating images - 0/{total_imgs}") + super().sample(step, is_first) + self.maybe_stop() + self.update_status("running", "Training") + + def save(self, step=None): + self.maybe_stop() + self.update_status("running", "Saving model") + super().save(step) + self.maybe_stop() + self.update_status("running", "Training") diff --git a/ai-toolkit/extensions_built_in/sd_trainer/SDTrainer.py b/ai-toolkit/extensions_built_in/sd_trainer/SDTrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e7274da26467eeac03140dd538e1fb72765601 --- /dev/null +++ b/ai-toolkit/extensions_built_in/sd_trainer/SDTrainer.py @@ -0,0 +1,2186 @@ +import os +import random +from collections 
import OrderedDict +from typing import Union, Literal, List, Optional + +import numpy as np +from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel + +import torch.functional as F +from safetensors.torch import load_file +from torch.utils.data import DataLoader, ConcatDataset + +from toolkit import train_tools +from toolkit.basic import value_map, adain, get_mean_std +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.config_modules import GenerateImageConfig +from toolkit.data_loader import get_dataloader_datasets +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO +from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType +from toolkit.image_utils import show_tensors, show_latents +from toolkit.ip_adapter import IPAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.print import print_acc +from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork +from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \ + apply_learnable_snr_gos, LearnableSNRGamma +import gc +import torch +from jobs.process import BaseSDTrainProcess +from torchvision import transforms +from diffusers import EMAModel +import math +from toolkit.train_tools import precondition_model_outputs_flow_match +from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe +from toolkit.util.losses import wavelet_loss, stepped_loss +import torch.nn.functional as F +from toolkit.unloader import unload_text_encoder +from PIL import Image +from torchvision.transforms import functional as TF +from toolkit.basic import flush + + +adapter_transforms = transforms.Compose([ + transforms.ToTensor(), +]) + + +class SDTrainer(BaseSDTrainProcess): + + def __init__(self, process_id: int, job, 
config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None] + self.do_prior_prediction = False + self.do_long_prompts = False + self.do_guided_loss = False + self.taesd: Optional[AutoencoderTiny] = None + + self._clip_image_embeds_unconditional: Union[List[str], None] = None + self.negative_prompt_pool: Union[List[str], None] = None + self.batch_negative_prompt: Union[List[str], None] = None + + self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16" + + self.do_grad_scale = True + if self.is_fine_tuning and self.is_bfloat: + self.do_grad_scale = False + if self.adapter_config is not None: + if self.adapter_config.train: + self.do_grad_scale = False + + # if self.train_config.dtype in ["fp16", "float16"]: + # # patch the scaler to allow fp16 training + # org_unscale_grads = self.scaler._unscale_grads_ + # def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + # return org_unscale_grads(optimizer, inv_scale, found_inf, True) + # self.scaler._unscale_grads_ = _unscale_grads_replacer + + self.cached_blank_embeds: Optional[PromptEmbeds] = None + self.cached_trigger_embeds: Optional[PromptEmbeds] = None + self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None + + self.dfe: Optional[DiffusionFeatureExtractor] = None + self.unconditional_embeds = None + + if self.train_config.diff_output_preservation: + if self.trigger_word is None: + raise ValueError("diff_output_preservation requires a trigger_word to be set") + if self.network_config is None: + raise ValueError("diff_output_preservation requires a network to be set") + if self.train_config.train_text_encoder: + raise ValueError("diff_output_preservation is not supported with train_text_encoder") + + if self.train_config.blank_prompt_preservation: + if self.network_config is None: + raise ValueError("blank_prompt_preservation requires a network to be set") 
+ + if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation: + # always do a prior prediction when doing output preservation + self.do_prior_prediction = True + + # store the loss target for a batch so we can use it in a loss + self._guidance_loss_target_batch: float = 0.0 + if isinstance(self.train_config.guidance_loss_target, (int, float)): + self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target) + elif isinstance(self.train_config.guidance_loss_target, list): + self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0]) + else: + raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}") + + + def before_model_load(self): + pass + + def cache_sample_prompts(self): + if self.train_config.disable_sampling: + return + if self.sample_config is not None and self.sample_config.samples is not None and len(self.sample_config.samples) > 0: + # cache all the samples + self.sd.sample_prompts_cache = [] + sample_folder = os.path.join(self.save_root, 'samples') + output_path = os.path.join(sample_folder, 'test.jpg') + for i in range(len(self.sample_config.prompts)): + sample_item = self.sample_config.samples[i] + prompt = self.sample_config.prompts[i] + + # needed so we can autoparse the prompt to handle flags + gen_img_config = GenerateImageConfig( + prompt=prompt, # it will autoparse the prompt + negative_prompt=sample_item.neg, + output_path=output_path, + ctrl_img=sample_item.ctrl_img, + ctrl_img_1=sample_item.ctrl_img_1, + ctrl_img_2=sample_item.ctrl_img_2, + ctrl_img_3=sample_item.ctrl_img_3, + ) + + has_control_images = False + if gen_img_config.ctrl_img is not None or gen_img_config.ctrl_img_1 is not None or gen_img_config.ctrl_img_2 is not None or gen_img_config.ctrl_img_3 is not None: + has_control_images = True + # see if we need to encode the control images + if self.sd.encode_control_in_text_embeddings and has_control_images: + + 
ctrl_img_list = [] + + if gen_img_config.ctrl_img is not None: + ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img) + + if gen_img_config.ctrl_img_1 is not None: + ctrl_img_1 = Image.open(gen_img_config.ctrl_img_1).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_1 = ( + TF.to_tensor(ctrl_img_1) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_1) + if gen_img_config.ctrl_img_2 is not None: + ctrl_img_2 = Image.open(gen_img_config.ctrl_img_2).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_2 = ( + TF.to_tensor(ctrl_img_2) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_2) + if gen_img_config.ctrl_img_3 is not None: + ctrl_img_3 = Image.open(gen_img_config.ctrl_img_3).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_3 = ( + TF.to_tensor(ctrl_img_3) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_3) + + if self.sd.has_multiple_control_images: + ctrl_img = ctrl_img_list + else: + ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None + + + positive = self.sd.encode_prompt( + gen_img_config.prompt, + control_images=ctrl_img + ).to('cpu') + negative = self.sd.encode_prompt( + gen_img_config.negative_prompt, + control_images=ctrl_img + ).to('cpu') + else: + positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu') + negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu') + + self.sd.sample_prompts_cache.append({ + 'conditional': positive, + 'unconditional': negative + }) + + + def before_dataset_load(self): + self.assistant_adapter = None + # get adapter assistant if one is set + if self.train_config.adapter_assist_name_or_path is not None: + adapter_path = 
self.train_config.adapter_assist_name_or_path + + if self.train_config.adapter_assist_type == "t2i": + # dont name this adapter since we are not training it + self.assistant_adapter = T2IAdapter.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) + ).to(self.device_torch) + elif self.train_config.adapter_assist_type == "control_net": + self.assistant_adapter = ControlNetModel.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) + ).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + else: + raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}") + + self.assistant_adapter.eval() + self.assistant_adapter.requires_grad_(False) + flush() + if self.train_config.train_turbo and self.train_config.show_turbo_outputs: + if self.model_config.is_xl: + self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", + torch_dtype=get_torch_dtype(self.train_config.dtype)) + else: + self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd", + torch_dtype=get_torch_dtype(self.train_config.dtype)) + self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch) + self.taesd.eval() + self.taesd.requires_grad_(False) + + def hook_before_train_loop(self): + super().hook_before_train_loop() + if self.is_caching_text_embeddings: + # make sure model is on cpu for this part so we don't oom. 
+ self.sd.unet.to('cpu') + + # cache unconditional embeds (blank prompt) + with torch.no_grad(): + kwargs = {} + if self.sd.encode_control_in_text_embeddings: + # just do a blank image for unconditionals + control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] + + kwargs['control_images'] = control_image + self.unconditional_embeds = self.sd.encode_prompt( + [self.train_config.unconditional_prompt], + long_prompts=self.do_long_prompts, + **kwargs + ).to( + self.device_torch, + dtype=self.sd.torch_dtype + ).detach() + + if self.train_config.do_prior_divergence: + self.do_prior_prediction = True + # move vae to device if we did not cache latents + if not self.is_latents_cached: + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + else: + # offload it. Already cached + self.sd.vae.to('cpu') + flush() + add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch) + if self.adapter is not None: + self.adapter.to(self.device_torch) + + # check if we have regs and using adapter and caching clip embeddings + has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0 + is_caching_clip_embeddings = self.datasets is not None and any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))]) + + if has_reg and is_caching_clip_embeddings: + # we need a list of unconditional clip image embeds from other datasets to handle regs + unconditional_clip_image_embeds = [] + datasets = get_dataloader_datasets(self.data_loader) + for i in range(len(datasets)): + unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache + + if len(unconditional_clip_image_embeds) == 0: + raise ValueError("No unconditional clip image embeds found. 
This should not happen") + + self._clip_image_embeds_unconditional = unconditional_clip_image_embeds + + if self.train_config.negative_prompt is not None: + if os.path.exists(self.train_config.negative_prompt): + with open(self.train_config.negative_prompt, 'r') as f: + self.negative_prompt_pool = f.readlines() + # remove empty + self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""] + else: + # single prompt + self.negative_prompt_pool = [self.train_config.negative_prompt] + + # handle unload text encoder + if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: + print_acc("Caching embeddings and unloading text encoder") + with torch.no_grad(): + if self.train_config.train_text_encoder: + raise ValueError("Cannot unload text encoder if training text encoder") + # cache embeddings + self.sd.text_encoder_to(self.device_torch) + encode_kwargs = {} + if self.sd.encode_control_in_text_embeddings: + # just do a blank image for unconditionals + control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] + encode_kwargs['control_images'] = control_image + self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs) + if self.trigger_word is not None: + self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs) + if self.train_config.diff_output_preservation: + self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) + + self.cache_sample_prompts() + + print_acc("\n***** UNLOADING TEXT ENCODER *****") + if self.is_caching_text_embeddings: + print_acc("Embeddings cached to disk. 
We dont need the text encoder anymore") + else: + print_acc("This will train only with a blank prompt or trigger word, if set") + print_acc("If this is not what you want, remove the unload_text_encoder flag") + print_acc("***********************************") + print_acc("") + + # unload the text encoder + if self.is_caching_text_embeddings: + unload_text_encoder(self.sd) + else: + # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing. + # keep legacy usage for now. + self.sd.text_encoder_to("cpu") + flush() + + if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None: + # make sure we have this if not unloading + self.cached_blank_embeds = self.sd.encode_prompt("").to( + self.device_torch, + dtype=self.sd.torch_dtype + ).detach() + + if self.train_config.diffusion_feature_extractor_path is not None: + vae = self.sd.vae + # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": + # vae = self.sd.vae + self.dfe = load_dfe( + self.train_config.diffusion_feature_extractor_path, + vae=vae, + sd=self.sd + ) + self.dfe.to(self.device_torch) + if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing: + # must be set to train for gradient checkpointing to work + self.dfe.vision_encoder.train() + self.dfe.vision_encoder.gradient_checkpointing = True + elif hasattr(self.dfe, 'model') and self.train_config.gradient_checkpointing: + if hasattr(self.dfe.model, 'enable_gradient_checkpointing'): + self.dfe.model.train() + self.dfe.model.enable_gradient_checkpointing() + if hasattr(self.dfe.model, 'gradient_checkpointing_enable'): + self.dfe.model.train() + self.dfe.model.gradient_checkpointing_enable() + elif hasattr(self.dfe.model, 'gradient_checkpointing'): + self.dfe.model.train() + self.dfe.model.gradient_checkpointing = True + else: + print_acc("Warning: Could not enable gradient checkpointing on diffusion feature 
extractor model.") + else: + self.dfe.eval() + + # enable gradient checkpointing on the vae + if vae is not None and self.train_config.gradient_checkpointing: + try: + vae.enable_gradient_checkpointing() + vae.train() + except: + pass + + + def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch): + # to process turbo learning, we make one big step from our current timestep to the end + # we then denoise the prediction on that remaining step and target our loss to our target latents + # this currently only works on euler_a (that I know of). Would work on others, but needs to be coded to do so. + # needs to be done on each item in batch as they may all have different timesteps + batch_size = pred.shape[0] + pred_chunks = torch.chunk(pred, batch_size, dim=0) + noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0) + timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0) + latent_chunks = torch.chunk(batch.latents, batch_size, dim=0) + noise_chunks = torch.chunk(noise, batch_size, dim=0) + + with torch.no_grad(): + # set the timesteps to 1000 so we can capture them to calculate the sigmas + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.config.num_train_timesteps, + device=self.device_torch + ) + train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach() + + train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach() + + # set the scheduler to one timestep, we build the step and sigmas for each item in batch for the partial step + self.sd.noise_scheduler.set_timesteps( + 1, + device=self.device_torch + ) + + denoised_pred_chunks = [] + target_pred_chunks = [] + + for i in range(batch_size): + pred_item = pred_chunks[i] + noisy_latents_item = noisy_latents_chunks[i] + timesteps_item = timesteps_chunks[i] + latents_item = latent_chunks[i] + noise_item = noise_chunks[i] + with torch.no_grad(): + timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0] + 
single_step_timestep_schedule = [timesteps_item.squeeze().item()] + # extract the sigma idx for our midpoint timestep + sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch) + + end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1) + end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch) + + # add noise to our target + + # build the big sigma step. The to step will now be to 0 giving it a full remaining denoising half step + # self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach() + self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach() + # set our single timstep + self.sd.noise_scheduler.timesteps = torch.from_numpy( + np.array(single_step_timestep_schedule, dtype=np.float32) + ).to(device=self.device_torch) + + # set the step index to None so it will be recalculated on first step + self.sd.noise_scheduler._step_index = None + + denoised_latent = self.sd.noise_scheduler.step( + pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False + )[0] + + residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype( + self.train_config.dtype)) + # remove the residual noise from the denoised latents. 
Output should be a clean prediction (theoretically) + denoised_latent = denoised_latent - residual_noise + + denoised_pred_chunks.append(denoised_latent) + + denoised_latents = torch.cat(denoised_pred_chunks, dim=0) + # set the scheduler back to the original timesteps + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.config.num_train_timesteps, + device=self.device_torch + ) + + output = denoised_latents / self.sd.vae.config['scaling_factor'] + output = self.sd.vae.decode(output).sample + + if self.train_config.show_turbo_outputs: + # since we are completely denoising, we can show them here + with torch.no_grad(): + show_tensors(output) + + # we return our big partial step denoised latents as our pred and our untouched latents as our target. + # you can do mse against the two here or run the denoised through the vae for pixel space loss against the + # input tensor images. + + return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + + # you can expand these in a child class to make customization easier + def calculate_loss( + self, + noise_pred: torch.Tensor, + noise: torch.Tensor, + noisy_latents: torch.Tensor, + timesteps: torch.Tensor, + batch: 'DataLoaderBatchDTO', + mask_multiplier: Union[torch.Tensor, float] = 1.0, + prior_pred: Union[torch.Tensor, None] = None, + **kwargs + ): + loss_target = self.train_config.loss_target + is_reg = any(batch.get_is_reg_list()) + additional_loss = 0.0 + + prior_mask_multiplier = None + target_mask_multiplier = None + dtype = get_torch_dtype(self.train_config.dtype) + + has_mask = batch.mask_tensor is not None + + with torch.no_grad(): + loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32) + + if self.train_config.match_noise_norm: + # match the norm of the noise + noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True) + noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 
3), keepdim=True) + noise_pred = noise_pred * (noise_norm / noise_pred_norm) + + if self.train_config.pred_scaler != 1.0: + noise_pred = noise_pred * self.train_config.pred_scaler + + target = None + + if self.train_config.target_noise_multiplier != 1.0: + noise = noise * self.train_config.target_noise_multiplier + + if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask): + if self.train_config.correct_pred_norm and not is_reg: + with torch.no_grad(): + # this only works if doing a prior pred + if prior_pred is not None: + prior_mean = prior_pred.mean([2,3], keepdim=True) + prior_std = prior_pred.std([2,3], keepdim=True) + noise_mean = noise_pred.mean([2,3], keepdim=True) + noise_std = noise_pred.std([2,3], keepdim=True) + + mean_adjust = prior_mean - noise_mean + std_adjust = prior_std - noise_std + + mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier + std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier + + target_mean = noise_mean + mean_adjust + target_std = noise_std + std_adjust + + eps = 1e-5 + # match the noise to the prior + noise = (noise - noise_mean) / (noise_std + eps) + noise = noise * (target_std + eps) + target_mean + noise = noise.detach() + + if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: + assert not self.train_config.train_turbo + with torch.no_grad(): + prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype) + if len(noise_pred.shape) == 5: + # video B,C,T,H,W + lat_height = batch.latents.shape[3] + lat_width = batch.latents.shape[4] + else: + lat_height = batch.latents.shape[2] + lat_width = batch.latents.shape[3] + # resize to size of noise_pred + prior_mask = torch.nn.functional.interpolate(prior_mask, size=(lat_height, lat_width), mode='bicubic') + # stack first channel to match channels of noise_pred + prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1) + + if 
len(noise_pred.shape) == 5: + prior_mask = prior_mask.unsqueeze(2) # add time dimension back for video + prior_mask = prior_mask.repeat(1, 1, noise_pred.shape[2], 1, 1) + + prior_mask_multiplier = 1.0 - prior_mask + + # scale so it is a mean of 1 + prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean() + if hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + elif self.sd.is_flow_matching: + target = (noise - batch.latents).detach() + else: + target = noise + elif prior_pred is not None and not self.train_config.do_prior_divergence: + assert not self.train_config.train_turbo + # matching adapter prediction + target = prior_pred + elif self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) + elif self.train_config.do_signal_amplification: + if not self.sd.is_flow_matching: + raise ValueError("Signal amplification is only supported for flow matching models") + with torch.no_grad(): + nas = 1.0 - (timesteps / 1000).to(noise.device, dtype=noise.dtype) + nas = nas * self.train_config.signal_amplification_strength + while len(nas.shape) < len(noise.shape): + nas = nas.unsqueeze(-1) + aug = batch.latents * nas + target = noise - (batch.latents + aug) + target = target.detach() + elif hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + + elif self.sd.is_flow_matching: + # forward ODE + target = (noise - batch.latents).detach() + # reverse ODE + # target = (batch.latents - noise).detach() + else: + target = noise + + if self.dfe is not None: + if self.dfe.version == 1: + model = self.sd + if model is not None and hasattr(model, 'get_stepped_pred'): + stepped_latents = model.get_stepped_pred(noise_pred, noise) + else: + # stepped_latents = noise - noise_pred + # first we step the 
scheduler from current timestep to the very end for a full denoise + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + self.sd.noise_scheduler._step_index = None + self.sd.noise_scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index] + sigma_next = self.sd.noise_scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) + + stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype) + sl = stepped_latents + if len(sl.shape) == 5: + # video B,C,T,H,W + sl = sl.permute(0, 2, 1, 3, 4) # B,T,C,H,W + b, t, c, h, w = sl.shape + sl = sl.reshape(b * t, c, h, w) + pred_features = self.dfe(sl.float()) + with torch.no_grad(): + bl = batch.latents + bl = bl.to(self.sd.vae.device) + if len(bl.shape) == 5: + # video B,C,T,H,W + bl = bl.permute(0, 2, 1, 3, 4) # B,T,C,H,W + b, t, c, h, w = bl.shape + bl = bl.reshape(b * t, c, h, w) + target_features = self.dfe(bl.float()) + # scale dfe so it is weaker at higher noise levels + dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch) + + dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \ + self.train_config.diffusion_feature_extractor_weight * dfe_scaler + additional_loss += dfe_loss.mean() + elif self.dfe.version == 2: + # version 2 + # do diffusion feature extraction on target + with torch.no_grad(): + rectified_flow_target = noise.float() - batch.latents.float() + target_feature_list = self.dfe(torch.cat([rectified_flow_target, 
noise.float()], dim=1)) + + # do diffusion feature extraction on prediction + pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1)) + + dfe_loss = 0.0 + for i in range(len(target_feature_list)): + dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean") + + additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0 + elif self.dfe.version in [3, 4, 5, 6, 7, 8, 9, 10]: + dfe_loss = self.dfe( + noise=noise, + noise_pred=noise_pred, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + scheduler=self.sd.noise_scheduler + ) + additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight + else: + raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}") + + if self.train_config.do_guidance_loss: + with torch.no_grad(): + # we make cached blank prompt embeds that match the batch size + unconditional_embeds = concat_prompt_embeds( + [self.unconditional_embeds] * noisy_latents.shape[0], + ) + unconditional_target = self.predict_noise( + noisy_latents=noisy_latents, + timesteps=timesteps, + conditional_embeds=unconditional_embeds, + unconditional_embeds=None, + batch=batch, + ) + is_video = len(target.shape) == 5 + + if self.train_config.do_guidance_loss_cfg_zero: + # zero cfg + # ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557 + batch_size = target.shape[0] + positive_flat = target.view(batch_size, -1) + negative_flat = unconditional_target.view(batch_size, -1) + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + alpha = st_star + + alpha = alpha.view(batch_size, 1, 1, 1) if not 
is_video else alpha.view(batch_size, 1, 1, 1, 1) + else: + alpha = 1.0 + + guidance_scale = self._guidance_loss_target_batch + if isinstance(guidance_scale, list): + guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype) + guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1) + + unconditional_target = unconditional_target * alpha + target = unconditional_target + guidance_scale * (target - unconditional_target) + + if self.train_config.do_differential_guidance: + with torch.no_grad(): + guidance_scale = self.train_config.differential_guidance_scale + target = noise_pred + guidance_scale * (target - noise_pred) + + if target is None: + target = noise + + pred = noise_pred + + if self.train_config.train_turbo: + pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch) + + ignore_snr = False + + if loss_target == 'source' or loss_target == 'unaugmented': + assert not self.train_config.train_turbo + # ignore_snr = True + if batch.sigmas is None: + raise ValueError("Batch sigmas is None. 
This should not happen") + + # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190 + denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents + weighing = batch.sigmas ** -2.0 + if loss_target == 'source': + # denoise the latent and compare to the latent in the batch + target = batch.latents + elif loss_target == 'unaugmented': + # we have to encode images into latents for now + # we also denoise as the unaugmented tensor is not a noisy diffirental + with torch.no_grad(): + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype) + unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier + target = unaugmented_latents.detach() + + # Get the target for loss depending on the prediction type + if self.sd.noise_scheduler.config.prediction_type == "epsilon": + target = target # we are computing loss against denoise latents + elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": + target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") + + # mse loss without reduction + loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2) + loss = loss_per_element + else: + local_loss_scale = 1.0 + if self.train_config.t0_loss_target or self.train_config.do_fft_loss: + # do the loss on a stepped timestep 0 prediction + # doto handle doing priors, preservations, masking, etc + with torch.no_grad(): + tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0 + # expand shape to match noise_pred + while len(tv.shape) < len(noise_pred.shape): + tv = tv.unsqueeze(-1) + # min 0.001 + tv = torch.clamp(tv, min=0.001) + + # step latent, use here or with do_fft_loss + t0 = noisy_latents - tv * noise_pred + + if self.train_config.t0_loss_target: + # 
replace the loss targets and pred + target = batch.latents.detach() + pred = t0 + # handle velocity equiv loss if set. This scales t0 loss to match velocity of flowmatchhing loss + if self.train_config.t0_velocity_equiv_weight: + velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2) + local_loss_scale = velocity_equiv_weight + + if self.train_config.do_fft_loss: + with torch.no_grad(): + target_mag = torch.fft.rfft2(batch.latents.to(t0.device).float(), norm="ortho").abs() + pred_mag = torch.fft.rfft2(t0.float(), norm="ortho").abs() + fft_loss = F.mse_loss(pred_mag, target_mag, reduction="none") + if self.train_config.do_fft_velocity_equiv_weight: + velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2) + fft_loss = fft_loss * velocity_equiv_weight + additional_loss += fft_loss.mean() + if self.train_config.loss_type == "pseudo_huber": + diff = pred.float() - target.float() + c=0.01 + loss =(torch.sqrt(diff.pow(2) + c ** 2) - c) + elif self.train_config.loss_type == "mae": + loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") + elif self.train_config.loss_type == "wavelet": + loss = wavelet_loss(pred, batch.latents, noise) + elif self.train_config.loss_type == "stepped": + loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler) + # the way this loss works, it is low, increase it to match predictable LR effects + loss = loss * 10.0 + else: + loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + + loss = loss * local_loss_scale + + # apply model specific loss scaling + loss = self.sd.scale_loss(loss) + + do_weighted_timesteps = False + if self.sd.is_flow_matching: + if self.train_config.linear_timesteps or self.train_config.linear_timesteps2: + do_weighted_timesteps = True + if self.train_config.timestep_type == "weighted": + # use the noise scheduler to get the weights for the timesteps + do_weighted_timesteps = True + + # handle linear timesteps 
and only adjust the weight of the timesteps + if do_weighted_timesteps: + # calculate the weights for the timesteps + timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps( + timesteps, + v2=self.train_config.linear_timesteps2, + timestep_type=self.train_config.timestep_type + ).to(loss.device, dtype=loss.dtype) + if len(loss.shape) == 4: + timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() + elif len(loss.shape) == 5: + timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach() + loss = loss * timestep_weight + + if self.train_config.do_prior_divergence and prior_pred is not None: + loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0) + + if self.train_config.train_turbo: + mask_multiplier = mask_multiplier[:, 3:, :, :] + # resize to the size of the loss + mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest') + + # multiply by our mask + try: + if len(noise_pred.shape) == 5: + # video B,C,T,H,W + mask_multiplier = mask_multiplier.unsqueeze(2) # add time dimension back for video + mask_multiplier = mask_multiplier.repeat(1, 1, noise_pred.shape[2], 1, 1) + loss = loss * mask_multiplier + except Exception as e: + # todo handle mask with video models + print("Could not apply mask multiplier to loss") + print(e) + pass + + prior_loss = None + if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None: + assert not self.train_config.train_turbo + if self.train_config.loss_type == "mae": + prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none") + else: + prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") + + prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier + if torch.isnan(prior_loss).any(): + print_acc("Prior loss is nan") + prior_loss = None + else: + 
if len(noise_pred.shape) == 5: + # video B,C,T,H,W + prior_loss = prior_loss.mean([1, 2, 3, 4]) + else: + prior_loss = prior_loss.mean([1, 2, 3]) + # loss = loss + prior_loss + # loss = loss + prior_loss + # loss = loss + prior_loss + if len(noise_pred.shape) == 5: + loss = loss.mean([1, 2, 3, 4]) + else: + loss = loss.mean([1, 2, 3]) + # apply loss multiplier before prior loss + # multiply by our mask + try: + loss = loss * loss_multiplier + except: + # todo handle mask with video models + pass + if prior_loss is not None: + loss = loss + prior_loss + + if not self.train_config.train_turbo: + if self.train_config.learnable_snr_gos: + # add snr_gamma + loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos) + elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr: + # add snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, + fixed=True) + elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # check for audio loss + if batch.audio_pred is not None and batch.audio_target is not None: + audio_loss = torch.nn.functional.mse_loss(batch.audio_pred.float(), batch.audio_target.float(), reduction="mean") + audio_loss = audio_loss * self.train_config.audio_loss_multiplier + loss = loss + audio_loss + + # check for additional losses + if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None: + + loss = loss + self.adapter.additional_loss.mean() + self.adapter.additional_loss = None + + if self.train_config.target_norm_std: + # seperate out the batch and channels + pred_std = noise_pred.std([2, 3], keepdim=True) + norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean() 
+ loss = loss + norm_std_loss + + + loss = loss + additional_loss + + if self.train_config.max_loss_debug and self.train_config.max_loss is not None: + if loss.item() > self.train_config.max_loss: + print_acc(f"Loss {loss.item()} is greater than max loss {self.train_config.max_loss}. Clipping to max loss.") + print_acc(f"timesteps: {timesteps}") + + if self.train_config.max_loss is not None: + loss = torch.clamp(loss, max=self.train_config.max_loss) + + return loss + + def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): + return batch + + def get_guided_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + loss = get_guidance_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + sd=self.sd, + unconditional_embeds=unconditional_embeds, + train_config=self.train_config, + **kwargs + ) + + return loss + + + # ------------------------------------------------------------------ + # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative + # Modelling”, 2025 – see Alg. 1 + Eq. 
(6) of the paper) + # This version avoids jvp / double-back-prop issues with Flash-Attention + # adapted from the work of lodestonerock + # ------------------------------------------------------------------ + def get_mean_flow_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + dtype = get_torch_dtype(self.train_config.dtype) + total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps) # e.g. 1000 + base_eps = 1e-3 + min_time_gap = 1e-2 + + with torch.no_grad(): + num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps + batch_size = batch.latents.shape[0] + timestep_t_list = [] + timestep_r_list = [] + + for i in range(batch_size): + t1 = random.randint(0, num_train_timesteps - 1) + t2 = random.randint(0, num_train_timesteps - 1) + t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)] + t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)] + if (t_t - t_r).item() < min_time_gap * 1000: + scaled_time_gap = min_time_gap * 1000 + if t_t.item() + scaled_time_gap > 1000: + t_r = t_r - scaled_time_gap + else: + t_t = t_t + scaled_time_gap + timestep_t_list.append(t_t) + timestep_r_list.append(t_r) + + timesteps_t = torch.stack(timestep_t_list, dim=0).float() + timesteps_r = torch.stack(timestep_r_list, dim=0).float() + + t_frac = timesteps_t / total_steps # [0,1] + r_frac = timesteps_r / total_steps # [0,1] + + latents_clean = batch.latents.to(dtype) + noise_sample = noise.to(dtype) + + lerp_vector = latents_clean * (1.0 - t_frac[:, None, None, None]) + noise_sample * t_frac[:, None, None, None] + + eps = base_eps + + # concatenate timesteps as input for u(z, r, t) + timesteps_cat = torch.cat([t_frac, r_frac], dim=0) * total_steps + + # model predicts u(z, r, t) + u_pred = 
self.predict_noise( + noisy_latents=lerp_vector.to(dtype), + timesteps=timesteps_cat.to(dtype), + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + + with torch.no_grad(): + t_frac_plus_eps = (t_frac + eps).clamp(0.0, 1.0) + lerp_perturbed = latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None]) + noise_sample * t_frac_plus_eps[:, None, None, None] + timesteps_cat_perturbed = torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps + + u_perturbed = self.predict_noise( + noisy_latents=lerp_perturbed.to(dtype), + timesteps=timesteps_cat_perturbed.to(dtype), + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + + # compute du/dt via finite difference (detached) + du_dt = (u_perturbed - u_pred).detach() / eps + # du_dt = (u_perturbed - u_pred).detach() + du_dt = du_dt.to(dtype) + + + time_gap = (t_frac - r_frac)[:, None, None, None].to(dtype) + time_gap.clamp(min=1e-4) + u_shifted = u_pred + time_gap * du_dt + # u_shifted = u_pred + du_dt / time_gap + # u_shifted = u_pred + + # a step is done like this: + # stepped_latent = model_input + (timestep_next - timestep) * model_output + + # flow target velocity + # v_target = (noise_sample - latents_clean) / time_gap + # flux predicts opposite of velocity, so we need to invert it + v_target = (latents_clean - noise_sample) / time_gap + + # compute loss + loss = torch.nn.functional.mse_loss( + u_shifted.float(), + v_target.float(), + reduction='none' + ) + + with torch.no_grad(): + pure_loss = loss.mean().detach() + pure_loss.requires_grad_(True) + + loss = loss.mean() + if loss.item() > 1e3: + pass + self.accelerator.backward(loss) + return pure_loss + + + + def get_prior_prediction( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 
'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + conditioned_prompts=None, + **kwargs + ): + # todo for embeddings, we need to run without trigger words + was_unet_training = self.sd.unet.training + was_network_active = False + if self.network is not None: + was_network_active = self.network.is_active + self.network.is_active = False + can_disable_adapter = False + was_adapter_active = False + if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or + isinstance(self.adapter, ReferenceAdapter) or + (isinstance(self.adapter, CustomAdapter)) + ): + can_disable_adapter = True + was_adapter_active = self.adapter.is_active + self.adapter.is_active = False + + if self.train_config.unload_text_encoder and self.adapter is not None and not isinstance(self.adapter, CustomAdapter): + raise ValueError("Prior predictions currently do not support unloading text encoder with adapter") + # do a prediction here so we can match its output with network multiplier set to 0.0 + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + + embeds_to_use = conditional_embeds.clone().detach() + # handle clip vision adapter by removing triggers from prompt and replacing with the class name + if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None: + prompt_list = batch.get_caption_list() + class_name = '' + + triggers = ['[trigger]', '[name]'] + remove_tokens = [] + + if self.embed_config is not None: + triggers.append(self.embed_config.trigger) + for i in range(1, self.embed_config.tokens): + remove_tokens.append(f"{self.embed_config.trigger}_{i}") + if self.embed_config.trigger_class_name is not None: + class_name = self.embed_config.trigger_class_name + + if self.adapter is not None: + triggers.append(self.adapter_config.trigger) + for i in range(1, self.adapter_config.num_tokens): + remove_tokens.append(f"{self.adapter_config.trigger}_{i}") + if 
self.adapter_config.trigger_class_name is not None: + class_name = self.adapter_config.trigger_class_name + + for idx, prompt in enumerate(prompt_list): + for remove_token in remove_tokens: + prompt = prompt.replace(remove_token, '') + for trigger in triggers: + prompt = prompt.replace(trigger, class_name) + prompt_list[idx] = prompt + + if batch.prompt_embeds is not None: + embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype) + else: + prompt_kwargs = {} + if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None: + prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype) + embeds_to_use = self.sd.encode_prompt( + prompt_list, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype, + **prompt_kwargs + ).detach() + + # dont use network on this + # self.network.multiplier = 0.0 + self.sd.unet.eval() + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux and not self.sd.is_lumina2: + # we need to remove the image embeds from the prompt except for flux + embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach() + end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens + embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :] + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.clone().detach() + unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos] + + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + + guidance_embedding_scale = self.train_config.cfg_scale + if self.train_config.do_guidance_loss: + guidance_embedding_scale = self._guidance_loss_target_batch + + prior_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=embeds_to_use.to(self.device_torch, 
dtype=dtype).detach(), + unconditional_embeddings=unconditional_embeds, + timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + guidance_embedding_scale=guidance_embedding_scale, + rescale_cfg=self.train_config.cfg_rescale, + batch=batch, + **pred_kwargs # adapter residuals in here + ) + if was_unet_training: + self.sd.unet.train() + prior_pred = prior_pred.detach() + # remove the residuals as we wont use them on prediction when matching control + if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs: + del pred_kwargs['down_intrablock_additional_residuals'] + if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs: + del pred_kwargs['down_block_additional_residuals'] + if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs: + del pred_kwargs['mid_block_additional_residual'] + + if can_disable_adapter: + self.adapter.is_active = was_adapter_active + # restore network + # self.network.multiplier = network_weight_list + if self.network is not None: + self.network.is_active = was_network_active + return prior_pred + + def before_unet_predict(self): + pass + + def after_unet_predict(self): + pass + + def end_of_training_loop(self): + pass + + def predict_noise( + self, + noisy_latents: torch.Tensor, + timesteps: Union[int, torch.Tensor] = 1, + conditional_embeds: Union[PromptEmbeds, None] = None, + unconditional_embeds: Union[PromptEmbeds, None] = None, + batch: Optional['DataLoaderBatchDTO'] = None, + is_primary_pred: bool = False, + **kwargs, + ): + dtype = get_torch_dtype(self.train_config.dtype) + guidance_embedding_scale = self.train_config.cfg_scale + if self.train_config.do_guidance_loss: + guidance_embedding_scale = self._guidance_loss_target_batch + return self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeddings=unconditional_embeds, + 
timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + guidance_embedding_scale=guidance_embedding_scale, + detach_unconditional=False, + rescale_cfg=self.train_config.cfg_rescale, + bypass_guidance_embedding=self.train_config.bypass_guidance_embedding, + batch=batch, + **kwargs + ) + + + def train_single_accumulation(self, batch: DataLoaderBatchDTO): + print_acc("[spock-debug] train_single_accumulation: ENTER", flush=True) + with torch.no_grad(): + self.timer.start('preprocess_batch') + print_acc("[spock-debug] train_single_accumulation: preprocess_batch", flush=True) + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_raw(batch) + batch = self.preprocess_batch(batch) + print_acc("[spock-debug] train_single_accumulation: preprocess_batch DONE", flush=True) + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_processed(batch) + dtype = get_torch_dtype(self.train_config.dtype) + # sanity check + if self.sd.vae.dtype != self.sd.vae_torch_dtype: + self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + if encoder.dtype != self.sd.te_torch_dtype: + encoder.to(self.sd.te_torch_dtype) + else: + if self.sd.text_encoder.dtype != self.sd.te_torch_dtype: + self.sd.text_encoder.to(self.sd.te_torch_dtype) + + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + if self.train_config.do_cfg or self.train_config.do_random_cfg: + # pick random negative prompts + if self.negative_prompt_pool is not None: + negative_prompts = [] + for i in range(noisy_latents.shape[0]): + num_neg = random.randint(1, self.train_config.max_negative_prompts) + this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)] + this_neg_prompt = ', '.join(this_neg_prompts) + negative_prompts.append(this_neg_prompt) + self.batch_negative_prompt = negative_prompts + else: + 
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] + + if self.adapter and isinstance(self.adapter, CustomAdapter): + # condition the prompt + # todo handle more than one adapter image + conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) + + network_weight_list = batch.get_network_weight_list() + if self.train_config.single_item_batching: + network_weight_list = network_weight_list + network_weight_list + + has_adapter_img = batch.control_tensor is not None + has_clip_image = batch.clip_image_tensor is not None + has_clip_image_embeds = batch.clip_image_embeds is not None + # force it to be true if doing regs as we handle those differently + if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]): + has_clip_image = True + if self._clip_image_embeds_unconditional is not None: + has_clip_image_embeds = True # we are caching embeds, handle that differently + has_clip_image = False + + # do prior pred if prior regularization batch + do_reg_prior = False + if any([batch.file_items[idx].prior_reg for idx in range(len(batch.file_items))]): + do_reg_prior = True + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: + raise ValueError( + "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. 
Please update your dataset config ") + + match_adapter_assist = False + + # check if we are matching the adapter assistant + if self.assistant_adapter: + if self.train_config.match_adapter_chance == 1.0: + match_adapter_assist = True + elif self.train_config.match_adapter_chance > 0.0: + match_adapter_assist = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) < self.train_config.match_adapter_chance + + self.timer.stop('preprocess_batch') + + is_reg = False + loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + for idx, file_item in enumerate(batch.file_items): + if file_item.is_reg: + loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight + is_reg = True + + adapter_images = None + sigmas = None + if has_adapter_img and (self.adapter or self.assistant_adapter): + with self.timer('get_adapter_images'): + # todo move this to data loader + if batch.control_tensor is not None: + adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() + # match in channels + if self.assistant_adapter is not None: + in_channels = self.assistant_adapter.config.in_channels + if adapter_images.shape[1] != in_channels: + # we need to match the channels + adapter_images = adapter_images[:, :in_channels, :, :] + else: + raise NotImplementedError("Adapter images now must be loaded with dataloader") + + clip_images = None + if has_clip_image: + with self.timer('get_clip_images'): + # todo move this to data loader + if batch.clip_image_tensor is not None: + clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach() + + mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + if batch.mask_tensor is not None and self.sd.do_masked_loss: + with self.timer('get_mask_multiplier'): + # upsampling no supported for bfloat16 + mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach() + # scale 
down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height) + if len(noisy_latents.shape) == 5: + # video B,C,T,H,W + h = noisy_latents.shape[3] + w = noisy_latents.shape[4] + else: + h = noisy_latents.shape[2] + w = noisy_latents.shape[3] + mask_multiplier = torch.nn.functional.interpolate( + mask_multiplier, size=(h, w) + ) + # expand to match latents + mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) + mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + # make avg 1.0 + mask_multiplier = mask_multiplier / mask_multiplier.mean() + + def get_adapter_multiplier(): + if self.adapter and isinstance(self.adapter, T2IAdapter): + # training a t2i adapter, not using as assistant. + return 1.0 + elif match_adapter_assist: + # training a texture. We want it high + adapter_strength_min = 0.9 + adapter_strength_max = 1.0 + else: + # training with assistance, we want it low + # adapter_strength_min = 0.4 + # adapter_strength_max = 0.7 + adapter_strength_min = 0.5 + adapter_strength_max = 1.1 + + adapter_conditioning_scale = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) + + adapter_conditioning_scale = value_map( + adapter_conditioning_scale, + 0.0, + 1.0, + adapter_strength_min, + adapter_strength_max + ) + return adapter_conditioning_scale + + # flush() + with self.timer('grad_setup'): + + # text encoding + grad_on_text_encoder = False + if self.train_config.train_text_encoder: + grad_on_text_encoder = True + + if self.embedding is not None: + grad_on_text_encoder = True + + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + grad_on_text_encoder = True + + if self.adapter_config and self.adapter_config.type == 'te_augmenter': + grad_on_text_encoder = True + + # have a blank network so we can wrap it in a context and set multipliers without checking every time + if self.network is not None: + network = self.network + else: 
+ network = BlankNetwork() + + # set the weights + network.multiplier = network_weight_list + + # activate network if it exits + + prompts_1 = conditioned_prompts + prompts_2 = None + if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl: + prompts_1 = batch.get_caption_short_list() + prompts_2 = conditioned_prompts + + # make the batch splits + if self.train_config.single_item_batching: + if self.model_config.refiner_name_or_path is not None: + raise ValueError("Single item batching is not supported when training the refiner") + batch_size = noisy_latents.shape[0] + # chunk/split everything + noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0) + noise_list = torch.chunk(noise, batch_size, dim=0) + timesteps_list = torch.chunk(timesteps, batch_size, dim=0) + conditioned_prompts_list = [[prompt] for prompt in prompts_1] + if imgs is not None: + imgs_list = torch.chunk(imgs, batch_size, dim=0) + else: + imgs_list = [None for _ in range(batch_size)] + if adapter_images is not None: + adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0) + else: + adapter_images_list = [None for _ in range(batch_size)] + if clip_images is not None: + clip_images_list = torch.chunk(clip_images, batch_size, dim=0) + else: + clip_images_list = [None for _ in range(batch_size)] + mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0) + if prompts_2 is None: + prompt_2_list = [None for _ in range(batch_size)] + else: + prompt_2_list = [[prompt] for prompt in prompts_2] + + else: + noisy_latents_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditioned_prompts_list = [prompts_1] + imgs_list = [imgs] + adapter_images_list = [adapter_images] + clip_images_list = [clip_images] + mask_multiplier_list = [mask_multiplier] + if prompts_2 is None: + prompt_2_list = [None] + else: + prompt_2_list = [prompts_2] + + for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, 
clip_images, mask_multiplier, prompt_2 in zip( + noisy_latents_list, + noise_list, + timesteps_list, + conditioned_prompts_list, + imgs_list, + adapter_images_list, + clip_images_list, + mask_multiplier_list, + prompt_2_list + ): + + # if self.train_config.negative_prompt is not None: + # # add negative prompt + # conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in + # range(len(conditioned_prompts))] + # if prompt_2 is not None: + # prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] + + with (network): + # encode clip adapter here so embeds are active for tokenizer + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + with self.timer('encode_clip_vision_embeds'): + if has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True + ) + else: + # just do a blank one + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, 512, 512), + device=self.device_torch, dtype=dtype + ), + is_training=True, + has_been_preprocessed=True, + drop=True + ) + # it will be injected into the tokenizer when called + self.adapter(conditional_clip_embeds) + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or is_reg): + quad_count = random.randint(1, 4) + self.adapter.train() + self.adapter.trigger_pre_te( + tensors_preprocessed=clip_images if not is_reg else None, # on regs we send none to get random noise + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + batch_tensor=batch.tensor if not is_reg else None, + batch_size=noisy_latents.shape[0] + ) + + with self.timer('encode_prompt'): + unconditional_embeds = None + prompt_kwargs = {} + if self.sd.encode_control_in_text_embeddings and 
batch.control_tensor is not None: + prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: + with torch.set_grad_enabled(False): + if batch.prompt_embeds is not None: + # use the cached embeds + conditional_embeds = batch.prompt_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + else: + embeds_to_use = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + if self.cached_trigger_embeds is not None and not is_reg: + embeds_to_use = self.cached_trigger_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + conditional_embeds = concat_prompt_embeds( + [embeds_to_use] * noisy_latents.shape[0] + ) + if self.train_config.do_cfg: + unconditional_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + unconditional_embeds = concat_prompt_embeds( + [unconditional_embeds] * noisy_latents.shape[0] + ) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + elif grad_on_text_encoder: + with torch.set_grad_enabled(True): + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + # todo only do one and repeat it + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + 
self.adapter.is_unconditional_run = False + else: + with torch.set_grad_enabled(False): + # make sure it is in eval mode + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + else: + self.sd.text_encoder.eval() + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + if self.sd.encode_control_in_text_embeddings and batch.control_tensor_list is not None: + prompt_kwargs['control_images'] = batch.control_tensor_list + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + if self.train_config.diff_output_preservation: + dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts] + dop_prompts_2 = None + if prompt_2 is not None: + dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2] + self.diff_output_preservation_embeds = self.sd.encode_prompt( + dop_prompts, dop_prompts_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + if self.train_config.do_cfg: + unconditional_embeds = unconditional_embeds.detach() + + if self.decorator: + conditional_embeds.text_embeds = self.decorator( + 
conditional_embeds.text_embeds + ) + if self.train_config.do_cfg: + unconditional_embeds.text_embeds = self.decorator( + unconditional_embeds.text_embeds, + is_unconditional=True + ) + + # flush() + pred_kwargs = {} + + if has_adapter_img: + if (self.adapter and isinstance(self.adapter, T2IAdapter)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): + with torch.set_grad_enabled(self.adapter is not None): + adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + down_block_additional_residuals = adapter(adapter_images) + if self.assistant_adapter: + # not training. detach + down_block_additional_residuals = [ + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in + down_block_additional_residuals + ] + else: + down_block_additional_residuals = [ + sample.to(dtype=dtype) * adapter_multiplier for sample in + down_block_additional_residuals + ] + + pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals + + if self.adapter and isinstance(self.adapter, IPAdapter): + with self.timer('encode_adapter_embeds'): + # number of images to do if doing a quad image + quad_count = random.randint(1, 4) + image_size = self.adapter.input_size + if has_clip_image_embeds: + # todo handle reg images better than this + if is_reg: + # get unconditional image embeds from cache + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + if self.train_config.do_cfg: + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + else: + 
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds_unconditional, + quad_count=quad_count + ) + elif is_reg: + # we will zero it out in the img embedder + clip_images = torch.zeros( + (noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach() + # drop will zero it out + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images, + drop=True, + is_training=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach(), + is_training=True, + drop=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + elif has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + # do cfg on clip embeds to normalize the embeddings for when doing cfg + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + drop=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + else: + print_acc("No Clip Image") + print_acc([file_item.path for file_item in batch.file_items]) + raise ValueError("Could not find clip image") + + if not self.adapter_config.train_image_encoder: + # we are not training the image encoder, so we 
need to detach the embeds + conditional_clip_embeds = conditional_clip_embeds.detach() + if self.train_config.do_cfg: + unconditional_clip_embeds = unconditional_clip_embeds.detach() + + with self.timer('encode_adapter'): + self.adapter.train() + conditional_embeds = self.adapter( + conditional_embeds.detach(), + conditional_clip_embeds, + is_unconditional=False + ) + if self.train_config.do_cfg: + unconditional_embeds = self.adapter( + unconditional_embeds.detach(), + unconditional_clip_embeds, + is_unconditional=True + ) + else: + # wipe out unconsitional + self.adapter.last_unconditional = None + + if self.adapter and isinstance(self.adapter, ReferenceAdapter): + # pass in our scheduler + self.adapter.noise_scheduler = self.lr_scheduler + if has_clip_image or has_adapter_img: + img_to_use = clip_images if has_clip_image else adapter_images + # currently 0-1 needs to be -1 to 1 + reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype) + self.adapter.set_reference_images(reference_images) + self.adapter.noise_scheduler = self.sd.noise_scheduler + elif is_reg: + self.adapter.set_blank_reference_images(noisy_latents.shape[0]) + else: + self.adapter.set_reference_images(None) + + prior_pred = None + + do_inverted_masked_prior = False + if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: + do_inverted_masked_prior = True + + do_correct_pred_norm_prior = self.train_config.correct_pred_norm + + do_guidance_prior = False + + if batch.unconditional_latents is not None: + # for this not that, we need a prior pred to normalize + guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type + if guidance_type == 'tnt': + do_guidance_prior = True + + if (( + has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): + with self.timer('prior predict'): + 
prior_embeds_to_use = conditional_embeds + # use diff_output_preservation embeds if doing dfe + if self.train_config.diff_output_preservation: + prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) + + if self.train_config.blank_prompt_preservation: + blank_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + prior_embeds_to_use = concat_prompt_embeds( + [blank_embeds] * noisy_latents.shape[0] + ) + + prior_pred = self.get_prior_prediction( + noisy_latents=noisy_latents, + conditional_embeds=prior_embeds_to_use, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + noise=noise, + batch=batch, + unconditional_embeds=unconditional_embeds, + conditioned_prompts=conditioned_prompts + ) + if prior_pred is not None: + prior_pred = prior_pred.detach() + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or self.adapter_config.type in ['llm_adapter', 'text_encoder']): + quad_count = random.randint(1, 4) + self.adapter.train() + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=conditional_embeds, + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + if self.train_config.do_cfg and unconditional_embeds is not None: + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=unconditional_embeds, + is_training=True, + has_been_preprocessed=True, + is_unconditional=True, + quad_count=quad_count + ) + + if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None: + self.adapter.add_extra_values(batch.extra_values.detach()) + + if self.train_config.do_cfg: + self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), + is_unconditional=True) + + if 
has_adapter_img: + if (self.adapter and isinstance(self.adapter, ControlNetModel)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): + if self.train_config.do_cfg: + raise ValueError("ControlNetModel is not supported with CFG") + with torch.set_grad_enabled(self.adapter is not None): + adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + # add_text_embeds is pooled_prompt_embeds for sdxl + added_cond_kwargs = {} + if self.sd.is_xl: + added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds + added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents) + down_block_res_samples, mid_block_res_sample = adapter( + noisy_latents, + timesteps, + encoder_hidden_states=conditional_embeds.text_embeds, + controlnet_cond=adapter_images, + conditioning_scale=1.0, + guess_mode=False, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + pred_kwargs['down_block_additional_residuals'] = down_block_res_samples + pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample + + if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list): + batch_size = noisy_latents.shape[0] + # update the guidance value, random float between guidance_loss_target[0] and guidance_loss_target[1] + self._guidance_loss_target_batch = [ + random.uniform( + self.train_config.guidance_loss_target[0], + self.train_config.guidance_loss_target[1] + ) for _ in range(batch_size) + ] + + self.before_unet_predict() + + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + with self.timer('condition_noisy_latents'): + # do it for the model + noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch) + if self.adapter and isinstance(self.adapter, CustomAdapter): + 
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) + + if self.train_config.timestep_type == 'next_sample': + with self.timer('next_sample_step'): + with torch.no_grad(): + + stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps] + stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies] + stepped_timesteps = torch.stack(stepped_timesteps, dim=0) + + # do a sample at the current timestep and step it, then determine new noise + next_sample_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + stepped_latents = self.sd.step_scheduler( + next_sample_pred, + noisy_latents, + timesteps, + self.sd.noise_scheduler + ) + # stepped latents is our new noisy latents. Now we need to determine noise in the current sample + noisy_latents = stepped_latents + original_samples = batch.latents.to(self.device_torch, dtype=dtype) + # todo calc next timestep, for now this may work as it + t_01 = (stepped_timesteps / 1000).to(original_samples.device) + if len(stepped_latents.shape) == 4: + t_01 = t_01.view(-1, 1, 1, 1) + elif len(stepped_latents.shape) == 5: + t_01 = t_01.view(-1, 1, 1, 1, 1) + else: + raise ValueError("Unknown stepped latents shape", stepped_latents.shape) + next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01 + noise = next_sample_noise + timesteps = stepped_timesteps + # do a prior pred if we have an unconditional image, we will swap out the giadance later + if batch.unconditional_latents is not None or self.do_guided_loss: + # do guided loss + loss = self.get_guided_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + 
network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + unconditional_embeds=unconditional_embeds, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + ) + + elif self.train_config.loss_type == 'mean_flow': + loss = self.get_mean_flow_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + unconditional_embeds=unconditional_embeds, + prior_pred=prior_pred, + ) + else: + with self.timer('predict_unet'): + print_acc("[spock-debug] predict_unet: CALLING model.forward() — first step, may compile/be slow", flush=True) + noise_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + is_primary_pred=True, + **pred_kwargs + ) + print_acc("[spock-debug] predict_unet: model.forward() RETURNED", flush=True) + self.after_unet_predict() + print_acc("[spock-debug] after_unet_predict() DONE", flush=True) + + with self.timer('calculate_loss'): + noise = noise.to(self.device_torch, dtype=dtype).detach() + prior_to_calculate_loss = prior_pred + # if we are doing diff_output_preservation and not noing inverted masked prior + # then we need to send none here so it will not target the prior + doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation + if doing_preservation and not do_inverted_masked_prior: + prior_to_calculate_loss = None + + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_to_calculate_loss, + ) + + if 
self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation: + with torch.no_grad(): + if self.train_config.diff_output_preservation: + preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) + elif self.train_config.blank_prompt_preservation: + blank_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + preservation_embeds = concat_prompt_embeds( + [blank_embeds] * noisy_latents.shape[0] + ) + preservation_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier + preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier + self.additional_logs['loss/normal'] = loss.item() + self.additional_logs['loss/preservation'] = preservation_loss.item() + loss = loss + preservation_loss + + # check if nan + if torch.isnan(loss): + print_acc("loss is nan") + loss = torch.zeros_like(loss).requires_grad_(True) + + with self.timer('backward'): + # todo we have multiplier seperated. works for now as res are not in same batch, but need to change + loss = loss * loss_multiplier.mean() + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. 
def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', List['DataLoaderBatchDTO']]):
    """Run one optimisation step over one batch or a list of accumulated batches.

    Every batch is forwarded to ``train_single_accumulation`` (which is
    expected to run the backward pass). Unless this is a gradient
    accumulation step, gradients are then clipped, the optimizer is stepped
    and gradients are cleared. The LR scheduler is stepped on every call.

    Args:
        batch: A single batch, or a list of batches to accumulate before the
            optimizer step.

    Returns:
        OrderedDict with one key, ``'loss'``: the mean of the per-batch
        losses as a Python float.

    Raises:
        ValueError: If ``batch`` is an empty list, or if the model is
            multistage and none of its boundary indices is trainable.
    """
    batch_list = batch if isinstance(batch, list) else [batch]
    if len(batch_list) == 0:
        # Previously this surfaced much later as ``None / 0`` after the
        # optimizer had already been stepped on stale gradients.
        raise ValueError("hook_train_loop received an empty batch list")

    total_loss = None
    self.optimizer.zero_grad()
    for single_batch in batch_list:
        if self.sd.is_multistage:
            # handle multistage switching
            needs_switch = (
                self.steps_this_boundary >= self.train_config.switch_boundary_every
                or self.current_boundary_index not in self.sd.trainable_multistage_boundaries
            )
            if needs_switch:
                # Walk the boundaries cyclically until a trainable one is found.
                # One full cycle visits every index, so the search is bounded
                # instead of spinning forever when nothing is trainable.
                num_boundaries = len(self.sd.multistage_boundaries)
                for _ in range(max(num_boundaries, 1)):
                    self.steps_this_boundary = 0
                    self.current_boundary_index += 1
                    if self.current_boundary_index >= num_boundaries:
                        self.current_boundary_index = 0
                    if self.current_boundary_index in self.sd.trainable_multistage_boundaries:
                        break
                else:
                    raise ValueError(
                        "No trainable multistage boundary found: "
                        "trainable_multistage_boundaries does not contain any valid boundary index"
                    )

        loss = self.train_single_accumulation(single_batch)
        self.steps_this_boundary += 1
        # Out-of-place add so the tensor returned for the first batch is not
        # mutated behind the callee's back.
        total_loss = loss if total_loss is None else total_loss + loss
        if len(batch_list) > 1 and self.model_config.low_vram:
            torch.cuda.empty_cache()

    if not self.is_grad_accumulation_step:
        # fix this for multi params
        if self.train_config.optimizer != 'adafactor':
            if isinstance(self.params[0], dict):
                # parameter groups: each group is clipped on its own
                for param_group in self.params:
                    self.accelerator.clip_grad_norm_(param_group['params'], self.train_config.max_grad_norm)
            else:
                self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
        # only step if we are not accumulating
        with self.timer('optimizer_step'):
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)
        if self.adapter and isinstance(self.adapter, CustomAdapter):
            self.adapter.post_weight_update()
        if self.ema is not None:
            with self.timer('ema_update'):
                self.ema.update()

    # TODO Should we only step scheduler on grad step? If so, need to recalculate last step
    with self.timer('scheduler_step'):
        self.lr_scheduler.step()

    if self.embedding is not None:
        with self.timer('restore_embeddings'):
            # Let's make sure we don't update any embedding weights besides the newly added token
            self.embedding.restore_embeddings()
    if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
        with self.timer('restore_adapter'):
            # Let's make sure we don't update any embedding weights besides the newly added token
            self.adapter.restore_embeddings()

    loss_dict = OrderedDict(
        {'loss': (total_loss / len(batch_list)).item()}
    )

    self.end_of_training_loop()

    return loss_dict
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
    """Set up trainer state on top of the base process and validate the config.

    Args:
        process_id: Forwarded unchanged to the base class constructor.
        job: Forwarded unchanged to the base class constructor.
        config: Forwarded unchanged to the base class constructor.
        **kwargs: Forwarded unchanged to the base class constructor.

    Raises:
        ValueError: If an output-preservation option is enabled without the
            settings it requires, or if ``guidance_loss_target`` is neither
            a number nor a list.
    """
    super().__init__(process_id, job, config, **kwargs)
    train_cfg = self.train_config

    # declared here, assigned later (see before_dataset_load)
    self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]

    # feature switches, all off until something below or a later hook enables them
    self.do_prior_prediction = False
    self.do_long_prompts = False
    self.do_guided_loss = False
    self.taesd: Optional[AutoencoderTiny] = None

    # prompt / clip pools filled in by later hooks
    self._clip_image_embeds_unconditional: Union[List[str], None] = None
    self.negative_prompt_pool: Union[List[str], None] = None
    self.batch_negative_prompt: Union[List[str], None] = None

    self.is_bfloat = train_cfg.dtype in ("bfloat16", "bf16")

    # gradient scaling is on unless we fine tune in bfloat16 or train an adapter
    fine_tuning_in_bfloat = self.is_fine_tuning and self.is_bfloat
    training_adapter = self.adapter_config is not None and self.adapter_config.train
    self.do_grad_scale = True
    if fine_tuning_in_bfloat or training_adapter:
        self.do_grad_scale = False

    # if self.train_config.dtype in ["fp16", "float16"]:
    #     # patch the scaler to allow fp16 training
    #     org_unscale_grads = self.scaler._unscale_grads_
    #     def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
    #         return org_unscale_grads(optimizer, inv_scale, found_inf, True)
    #     self.scaler._unscale_grads_ = _unscale_grads_replacer

    # embedding caches, populated in hook_before_train_loop
    self.cached_blank_embeds: Optional[PromptEmbeds] = None
    self.cached_trigger_embeds: Optional[PromptEmbeds] = None
    self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None

    self.dfe: Optional[DiffusionFeatureExtractor] = None
    self.unconditional_embeds = None

    use_diff_preservation = train_cfg.diff_output_preservation
    if use_diff_preservation:
        if self.trigger_word is None:
            raise ValueError("diff_output_preservation requires a trigger_word to be set")
        if self.network_config is None:
            raise ValueError("diff_output_preservation requires a network to be set")
        if train_cfg.train_text_encoder:
            raise ValueError("diff_output_preservation is not supported with train_text_encoder")

    use_blank_preservation = train_cfg.blank_prompt_preservation
    if use_blank_preservation and self.network_config is None:
        raise ValueError("blank_prompt_preservation requires a network to be set")

    if use_blank_preservation or use_diff_preservation:
        # always do a prior prediction when doing output preservation
        self.do_prior_prediction = True

    # store the loss target for a batch so we can use it in a loss
    self._guidance_loss_target_batch: float = 0.0
    guidance_target = train_cfg.guidance_loss_target
    if isinstance(guidance_target, (int, float)):
        self._guidance_loss_target_batch = float(guidance_target)
    elif isinstance(guidance_target, list):
        # a list is a range; start from its first entry
        self._guidance_loss_target_batch = float(guidance_target[0])
    else:
        raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}")
+ pass + + def cache_sample_prompts(self): + if self.train_config.disable_sampling: + return + if self.sample_config is not None and self.sample_config.samples is not None and len(self.sample_config.samples) > 0: + # cache all the samples + self.sd.sample_prompts_cache = [] + sample_folder = os.path.join(self.save_root, 'samples') + output_path = os.path.join(sample_folder, 'test.jpg') + for i in range(len(self.sample_config.prompts)): + sample_item = self.sample_config.samples[i] + prompt = self.sample_config.prompts[i] + + # needed so we can autoparse the prompt to handle flags + gen_img_config = GenerateImageConfig( + prompt=prompt, # it will autoparse the prompt + negative_prompt=sample_item.neg, + output_path=output_path, + ctrl_img=sample_item.ctrl_img, + ctrl_img_1=sample_item.ctrl_img_1, + ctrl_img_2=sample_item.ctrl_img_2, + ctrl_img_3=sample_item.ctrl_img_3, + ) + + has_control_images = False + if gen_img_config.ctrl_img is not None or gen_img_config.ctrl_img_1 is not None or gen_img_config.ctrl_img_2 is not None or gen_img_config.ctrl_img_3 is not None: + has_control_images = True + # see if we need to encode the control images + if self.sd.encode_control_in_text_embeddings and has_control_images: + + ctrl_img_list = [] + + if gen_img_config.ctrl_img is not None: + ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img) + + if gen_img_config.ctrl_img_1 is not None: + ctrl_img_1 = Image.open(gen_img_config.ctrl_img_1).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_1 = ( + TF.to_tensor(ctrl_img_1) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_1) + if gen_img_config.ctrl_img_2 is not None: + ctrl_img_2 = Image.open(gen_img_config.ctrl_img_2).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_2 = ( + 
TF.to_tensor(ctrl_img_2) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_2) + if gen_img_config.ctrl_img_3 is not None: + ctrl_img_3 = Image.open(gen_img_config.ctrl_img_3).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_3 = ( + TF.to_tensor(ctrl_img_3) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_3) + + if self.sd.has_multiple_control_images: + ctrl_img = ctrl_img_list + else: + ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None + + + positive = self.sd.encode_prompt( + gen_img_config.prompt, + control_images=ctrl_img + ).to('cpu') + negative = self.sd.encode_prompt( + gen_img_config.negative_prompt, + control_images=ctrl_img + ).to('cpu') + else: + positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu') + negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu') + + self.sd.sample_prompts_cache.append({ + 'conditional': positive, + 'unconditional': negative + }) + + + def before_dataset_load(self): + self.assistant_adapter = None + # get adapter assistant if one is set + if self.train_config.adapter_assist_name_or_path is not None: + adapter_path = self.train_config.adapter_assist_name_or_path + + if self.train_config.adapter_assist_type == "t2i": + # dont name this adapter since we are not training it + self.assistant_adapter = T2IAdapter.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) + ).to(self.device_torch) + elif self.train_config.adapter_assist_type == "control_net": + self.assistant_adapter = ControlNetModel.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) + ).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + else: + raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}") + + self.assistant_adapter.eval() + self.assistant_adapter.requires_grad_(False) + flush() + 
def hook_before_train_loop(self):
    """Prepare cached embeddings, device placement and helper models before training.

    In order, this:
      1. runs the parent hook and, when text embeddings are being cached,
         moves the UNet to the CPU first;
      2. encodes and stores the unconditional prompt embeddings;
      3. puts the VAE on the training device when latents are not cached,
         otherwise offloads it to the CPU;
      4. adds SNR data to the noise scheduler and moves the adapter to the
         training device, collecting unconditional CLIP image embeds when
         regularisation datasets are combined with cached clip vision;
      5. loads the negative prompt pool (a file with one prompt per line, or
         a single literal prompt);
      6. when requested, caches blank / trigger / preservation embeddings and
         sample prompts, then unloads the text encoder;
      7. when a diffusion feature extractor path is configured, loads it and
         enables gradient checkpointing on it and on the VAE if requested.

    Raises:
        ValueError: If no unconditional clip image embeds are found when they
            are required, or if the text encoder is to be unloaded while it
            is being trained.
    """
    super().hook_before_train_loop()

    def blank_control_kwargs():
        """Build ``encode_prompt`` kwargs carrying a blank control image when the model needs one."""
        if not self.sd.encode_control_in_text_embeddings:
            return {}
        # just do a blank image for unconditionals
        control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
        if self.sd.has_multiple_control_images:
            control_image = [control_image]
        return {'control_images': control_image}

    if self.is_caching_text_embeddings:
        # make sure model is on cpu for this part so we don't oom.
        self.sd.unet.to('cpu')

    # cache unconditional embeds (blank prompt)
    with torch.no_grad():
        self.unconditional_embeds = self.sd.encode_prompt(
            [self.train_config.unconditional_prompt],
            long_prompts=self.do_long_prompts,
            **blank_control_kwargs()
        ).to(
            self.device_torch,
            dtype=self.sd.torch_dtype
        ).detach()

    if self.train_config.do_prior_divergence:
        self.do_prior_prediction = True

    # move vae to device if we did not cache latents
    if not self.is_latents_cached:
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)
    else:
        # offload it. Already cached
        self.sd.vae.to('cpu')
        flush()

    add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)

    if self.adapter is not None:
        self.adapter.to(self.device_torch)

        # NOTE(review): nesting of the reg / clip handling under the adapter
        # check is reconstructed from the comment below ("using adapter") - confirm.
        # check if we have regs and using adapter and caching clip embeddings
        has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
        is_caching_clip_embeddings = self.datasets is not None and any(
            [dataset.cache_clip_vision_to_disk for dataset in self.datasets]
        )

        if has_reg and is_caching_clip_embeddings:
            # we need a list of unconditional clip image embeds from other datasets to handle regs
            unconditional_clip_image_embeds = []
            for dataset in get_dataloader_datasets(self.data_loader):
                unconditional_clip_image_embeds += dataset.clip_vision_unconditional_cache

            if len(unconditional_clip_image_embeds) == 0:
                raise ValueError("No unconditional clip image embeds found. This should not happen")

            self._clip_image_embeds_unconditional = unconditional_clip_image_embeds

    if self.train_config.negative_prompt is not None:
        if os.path.exists(self.train_config.negative_prompt):
            # the setting is a path: one prompt per line, blank lines dropped
            with open(self.train_config.negative_prompt, 'r') as f:
                self.negative_prompt_pool = f.readlines()
            # remove empty
            self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""]
        else:
            # single prompt
            self.negative_prompt_pool = [self.train_config.negative_prompt]

    # handle unload text encoder
    if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
        print_acc("Caching embeddings and unloading text encoder")
        with torch.no_grad():
            if self.train_config.train_text_encoder:
                raise ValueError("Cannot unload text encoder if training text encoder")
            # cache embeddings
            self.sd.text_encoder_to(self.device_torch)
            encode_kwargs = blank_control_kwargs()
            self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
            if self.trigger_word is not None:
                self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs)
            if self.train_config.diff_output_preservation:
                # NOTE(review): unlike the two calls above this one does not
                # receive encode_kwargs - confirm that is intended for models
                # that encode control images in their text embeddings.
                self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)

            self.cache_sample_prompts()

            print_acc("\n***** UNLOADING TEXT ENCODER *****")
            if self.is_caching_text_embeddings:
                print_acc("Embeddings cached to disk. We dont need the text encoder anymore")
            else:
                print_acc("This will train only with a blank prompt or trigger word, if set")
                print_acc("If this is not what you want, remove the unload_text_encoder flag")
            print_acc("***********************************")
            print_acc("")

            # unload the text encoder
            if self.is_caching_text_embeddings:
                unload_text_encoder(self.sd)
            else:
                # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing.
                # keep legacy usage for now.
                self.sd.text_encoder_to("cpu")
            flush()

    if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
        # make sure we have this if not unloading
        self.cached_blank_embeds = self.sd.encode_prompt("").to(
            self.device_torch,
            dtype=self.sd.torch_dtype
        ).detach()

    if self.train_config.diffusion_feature_extractor_path is not None:
        vae = self.sd.vae
        # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
        #     vae = self.sd.vae
        self.dfe = load_dfe(
            self.train_config.diffusion_feature_extractor_path,
            vae=vae,
            sd=self.sd
        )
        self.dfe.to(self.device_torch)
        if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing:
            # must be set to train for gradient checkpointing to work
            self.dfe.vision_encoder.train()
            self.dfe.vision_encoder.gradient_checkpointing = True
        elif hasattr(self.dfe, 'model') and self.train_config.gradient_checkpointing:
            dfe_model = self.dfe.model
            # Try every known way of switching checkpointing on and only warn
            # when none of them applied. The previous chain warned even after
            # enable_gradient_checkpointing() had succeeded.
            checkpointing_enabled = False
            if hasattr(dfe_model, 'enable_gradient_checkpointing'):
                dfe_model.train()
                dfe_model.enable_gradient_checkpointing()
                checkpointing_enabled = True
            if hasattr(dfe_model, 'gradient_checkpointing_enable'):
                dfe_model.train()
                dfe_model.gradient_checkpointing_enable()
                checkpointing_enabled = True
            elif hasattr(dfe_model, 'gradient_checkpointing'):
                dfe_model.train()
                dfe_model.gradient_checkpointing = True
                checkpointing_enabled = True
            if not checkpointing_enabled:
                print_acc("Warning: Could not enable gradient checkpointing on diffusion feature extractor model.")
        else:
            self.dfe.eval()

        # enable gradient checkpointing on the vae
        if vae is not None and self.train_config.gradient_checkpointing:
            try:
                vae.enable_gradient_checkpointing()
                vae.train()
            except Exception as e:
                # Still best-effort, but no longer swallows KeyboardInterrupt /
                # SystemExit and leaves a trace of what went wrong.
                print_acc(f"Warning: Could not enable gradient checkpointing on the VAE: {e}")
currently only works on euler_a (that I know of). Would work on others, but needs to be coded to do so. + # needs to be done on each item in batch as they may all have different timesteps + batch_size = pred.shape[0] + pred_chunks = torch.chunk(pred, batch_size, dim=0) + noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0) + timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0) + latent_chunks = torch.chunk(batch.latents, batch_size, dim=0) + noise_chunks = torch.chunk(noise, batch_size, dim=0) + + with torch.no_grad(): + # set the timesteps to 1000 so we can capture them to calculate the sigmas + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.config.num_train_timesteps, + device=self.device_torch + ) + train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach() + + train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach() + + # set the scheduler to one timestep, we build the step and sigmas for each item in batch for the partial step + self.sd.noise_scheduler.set_timesteps( + 1, + device=self.device_torch + ) + + denoised_pred_chunks = [] + target_pred_chunks = [] + + for i in range(batch_size): + pred_item = pred_chunks[i] + noisy_latents_item = noisy_latents_chunks[i] + timesteps_item = timesteps_chunks[i] + latents_item = latent_chunks[i] + noise_item = noise_chunks[i] + with torch.no_grad(): + timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0] + single_step_timestep_schedule = [timesteps_item.squeeze().item()] + # extract the sigma idx for our midpoint timestep + sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch) + + end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1) + end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch) + + # add noise to our target + + # build the big sigma step. 
The to step will now be to 0 giving it a full remaining denoising half step + # self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach() + self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach() + # set our single timstep + self.sd.noise_scheduler.timesteps = torch.from_numpy( + np.array(single_step_timestep_schedule, dtype=np.float32) + ).to(device=self.device_torch) + + # set the step index to None so it will be recalculated on first step + self.sd.noise_scheduler._step_index = None + + denoised_latent = self.sd.noise_scheduler.step( + pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False + )[0] + + residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype( + self.train_config.dtype)) + # remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically) + denoised_latent = denoised_latent - residual_noise + + denoised_pred_chunks.append(denoised_latent) + + denoised_latents = torch.cat(denoised_pred_chunks, dim=0) + # set the scheduler back to the original timesteps + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.config.num_train_timesteps, + device=self.device_torch + ) + + output = denoised_latents / self.sd.vae.config['scaling_factor'] + output = self.sd.vae.decode(output).sample + + if self.train_config.show_turbo_outputs: + # since we are completely denoising, we can show them here + with torch.no_grad(): + show_tensors(output) + + # we return our big partial step denoised latents as our pred and our untouched latents as our target. + # you can do mse against the two here or run the denoised through the vae for pixel space loss against the + # input tensor images. 
+ + return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + + # you can expand these in a child class to make customization easier + def calculate_loss( + self, + noise_pred: torch.Tensor, + noise: torch.Tensor, + noisy_latents: torch.Tensor, + timesteps: torch.Tensor, + batch: 'DataLoaderBatchDTO', + mask_multiplier: Union[torch.Tensor, float] = 1.0, + prior_pred: Union[torch.Tensor, None] = None, + **kwargs + ): + loss_target = self.train_config.loss_target + is_reg = any(batch.get_is_reg_list()) + additional_loss = 0.0 + + prior_mask_multiplier = None + target_mask_multiplier = None + dtype = get_torch_dtype(self.train_config.dtype) + + has_mask = batch.mask_tensor is not None + + with torch.no_grad(): + loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32) + + if self.train_config.match_noise_norm: + # match the norm of the noise + noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True) + noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True) + noise_pred = noise_pred * (noise_norm / noise_pred_norm) + + if self.train_config.pred_scaler != 1.0: + noise_pred = noise_pred * self.train_config.pred_scaler + + target = None + + if self.train_config.target_noise_multiplier != 1.0: + noise = noise * self.train_config.target_noise_multiplier + + if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask): + if self.train_config.correct_pred_norm and not is_reg: + with torch.no_grad(): + # this only works if doing a prior pred + if prior_pred is not None: + prior_mean = prior_pred.mean([2,3], keepdim=True) + prior_std = prior_pred.std([2,3], keepdim=True) + noise_mean = noise_pred.mean([2,3], keepdim=True) + noise_std = noise_pred.std([2,3], keepdim=True) + + mean_adjust = prior_mean - noise_mean + std_adjust = prior_std - noise_std + + mean_adjust = 
mean_adjust * self.train_config.correct_pred_norm_multiplier + std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier + + target_mean = noise_mean + mean_adjust + target_std = noise_std + std_adjust + + eps = 1e-5 + # match the noise to the prior + noise = (noise - noise_mean) / (noise_std + eps) + noise = noise * (target_std + eps) + target_mean + noise = noise.detach() + + if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: + assert not self.train_config.train_turbo + with torch.no_grad(): + prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype) + if len(noise_pred.shape) == 5: + # video B,C,T,H,W + lat_height = batch.latents.shape[3] + lat_width = batch.latents.shape[4] + else: + lat_height = batch.latents.shape[2] + lat_width = batch.latents.shape[3] + # resize to size of noise_pred + prior_mask = torch.nn.functional.interpolate(prior_mask, size=(lat_height, lat_width), mode='bicubic') + # stack first channel to match channels of noise_pred + prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1) + + if len(noise_pred.shape) == 5: + prior_mask = prior_mask.unsqueeze(2) # add time dimension back for video + prior_mask = prior_mask.repeat(1, 1, noise_pred.shape[2], 1, 1) + + prior_mask_multiplier = 1.0 - prior_mask + + # scale so it is a mean of 1 + prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean() + if hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + elif self.sd.is_flow_matching: + target = (noise - batch.latents).detach() + else: + target = noise + elif prior_pred is not None and not self.train_config.do_prior_divergence: + assert not self.train_config.train_turbo + # matching adapter prediction + target = prior_pred + elif self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, 
timesteps) + elif self.train_config.do_signal_amplification: + if not self.sd.is_flow_matching: + raise ValueError("Signal amplification is only supported for flow matching models") + with torch.no_grad(): + nas = 1.0 - (timesteps / 1000).to(noise.device, dtype=noise.dtype) + nas = nas * self.train_config.signal_amplification_strength + while len(nas.shape) < len(noise.shape): + nas = nas.unsqueeze(-1) + aug = batch.latents * nas + target = noise - (batch.latents + aug) + target = target.detach() + elif hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + + elif self.sd.is_flow_matching: + # forward ODE + target = (noise - batch.latents).detach() + # reverse ODE + # target = (batch.latents - noise).detach() + else: + target = noise + + if self.dfe is not None: + if self.dfe.version == 1: + model = self.sd + if model is not None and hasattr(model, 'get_stepped_pred'): + stepped_latents = model.get_stepped_pred(noise_pred, noise) + else: + # stepped_latents = noise - noise_pred + # first we step the scheduler from current timestep to the very end for a full denoise + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + self.sd.noise_scheduler._step_index = None + self.sd.noise_scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index] + sigma_next = self.sd.noise_scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) + + stepped_latents = stepped_latents.to(self.sd.vae.device, 
dtype=self.sd.vae.dtype) + sl = stepped_latents + if len(sl.shape) == 5: + # video B,C,T,H,W + sl = sl.permute(0, 2, 1, 3, 4) # B,T,C,H,W + b, t, c, h, w = sl.shape + sl = sl.reshape(b * t, c, h, w) + pred_features = self.dfe(sl.float()) + with torch.no_grad(): + bl = batch.latents + bl = bl.to(self.sd.vae.device) + if len(bl.shape) == 5: + # video B,C,T,H,W + bl = bl.permute(0, 2, 1, 3, 4) # B,T,C,H,W + b, t, c, h, w = bl.shape + bl = bl.reshape(b * t, c, h, w) + target_features = self.dfe(bl.float()) + # scale dfe so it is weaker at higher noise levels + dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch) + + dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \ + self.train_config.diffusion_feature_extractor_weight * dfe_scaler + additional_loss += dfe_loss.mean() + elif self.dfe.version == 2: + # version 2 + # do diffusion feature extraction on target + with torch.no_grad(): + rectified_flow_target = noise.float() - batch.latents.float() + target_feature_list = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1)) + + # do diffusion feature extraction on prediction + pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1)) + + dfe_loss = 0.0 + for i in range(len(target_feature_list)): + dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean") + + additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0 + elif self.dfe.version in [3, 4, 5, 6, 7, 8, 9, 10]: + dfe_loss = self.dfe( + noise=noise, + noise_pred=noise_pred, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + scheduler=self.sd.noise_scheduler + ) + additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight + else: + raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}") + + if self.train_config.do_guidance_loss: + with torch.no_grad(): + # 
we make cached blank prompt embeds that match the batch size + unconditional_embeds = concat_prompt_embeds( + [self.unconditional_embeds] * noisy_latents.shape[0], + ) + unconditional_target = self.predict_noise( + noisy_latents=noisy_latents, + timesteps=timesteps, + conditional_embeds=unconditional_embeds, + unconditional_embeds=None, + batch=batch, + ) + is_video = len(target.shape) == 5 + + if self.train_config.do_guidance_loss_cfg_zero: + # zero cfg + # ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557 + batch_size = target.shape[0] + positive_flat = target.view(batch_size, -1) + negative_flat = unconditional_target.view(batch_size, -1) + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + alpha = st_star + + alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1) + else: + alpha = 1.0 + + guidance_scale = self._guidance_loss_target_batch + if isinstance(guidance_scale, list): + guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype) + guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1) + + unconditional_target = unconditional_target * alpha + target = unconditional_target + guidance_scale * (target - unconditional_target) + + if self.train_config.do_differential_guidance: + with torch.no_grad(): + guidance_scale = self.train_config.differential_guidance_scale + target = noise_pred + guidance_scale * (target - noise_pred) + + if target is None: + target = noise + + pred = noise_pred + + if self.train_config.train_turbo: + pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch) + + ignore_snr 
= False + + if loss_target == 'source' or loss_target == 'unaugmented': + assert not self.train_config.train_turbo + # ignore_snr = True + if batch.sigmas is None: + raise ValueError("Batch sigmas is None. This should not happen") + + # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190 + denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents + weighing = batch.sigmas ** -2.0 + if loss_target == 'source': + # denoise the latent and compare to the latent in the batch + target = batch.latents + elif loss_target == 'unaugmented': + # we have to encode images into latents for now + # we also denoise as the unaugmented tensor is not a noisy diffirental + with torch.no_grad(): + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype) + unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier + target = unaugmented_latents.detach() + + # Get the target for loss depending on the prediction type + if self.sd.noise_scheduler.config.prediction_type == "epsilon": + target = target # we are computing loss against denoise latents + elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": + target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") + + # mse loss without reduction + loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2) + loss = loss_per_element + else: + local_loss_scale = 1.0 + if self.train_config.t0_loss_target or self.train_config.do_fft_loss: + # do the loss on a stepped timestep 0 prediction + # doto handle doing priors, preservations, masking, etc + with torch.no_grad(): + tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0 + # expand shape to match noise_pred + while len(tv.shape) < len(noise_pred.shape): + 
tv = tv.unsqueeze(-1) + # min 0.001 + tv = torch.clamp(tv, min=0.001) + + # step latent, use here or with do_fft_loss + t0 = noisy_latents - tv * noise_pred + + if self.train_config.t0_loss_target: + # replace the loss targets and pred + target = batch.latents.detach() + pred = t0 + # handle velocity equiv loss if set. This scales t0 loss to match velocity of flowmatchhing loss + if self.train_config.t0_velocity_equiv_weight: + velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2) + local_loss_scale = velocity_equiv_weight + + if self.train_config.do_fft_loss: + with torch.no_grad(): + target_mag = torch.fft.rfft2(batch.latents.to(t0.device).float(), norm="ortho").abs() + pred_mag = torch.fft.rfft2(t0.float(), norm="ortho").abs() + fft_loss = F.mse_loss(pred_mag, target_mag, reduction="none") + if self.train_config.do_fft_velocity_equiv_weight: + velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2) + fft_loss = fft_loss * velocity_equiv_weight + additional_loss += fft_loss.mean() + if self.train_config.loss_type == "pseudo_huber": + diff = pred.float() - target.float() + c=0.01 + loss =(torch.sqrt(diff.pow(2) + c ** 2) - c) + elif self.train_config.loss_type == "mae": + loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") + elif self.train_config.loss_type == "wavelet": + loss = wavelet_loss(pred, batch.latents, noise) + elif self.train_config.loss_type == "stepped": + loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler) + # the way this loss works, it is low, increase it to match predictable LR effects + loss = loss * 10.0 + else: + loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + + loss = loss * local_loss_scale + + # apply model specific loss scaling + loss = self.sd.scale_loss(loss) + + do_weighted_timesteps = False + if self.sd.is_flow_matching: + if self.train_config.linear_timesteps or self.train_config.linear_timesteps2: + 
do_weighted_timesteps = True + if self.train_config.timestep_type == "weighted": + # use the noise scheduler to get the weights for the timesteps + do_weighted_timesteps = True + + # handle linear timesteps and only adjust the weight of the timesteps + if do_weighted_timesteps: + # calculate the weights for the timesteps + timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps( + timesteps, + v2=self.train_config.linear_timesteps2, + timestep_type=self.train_config.timestep_type + ).to(loss.device, dtype=loss.dtype) + if len(loss.shape) == 4: + timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() + elif len(loss.shape) == 5: + timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach() + loss = loss * timestep_weight + + if self.train_config.do_prior_divergence and prior_pred is not None: + loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0) + + if self.train_config.train_turbo: + mask_multiplier = mask_multiplier[:, 3:, :, :] + # resize to the size of the loss + mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest') + + # multiply by our mask + try: + if len(noise_pred.shape) == 5: + # video B,C,T,H,W + mask_multiplier = mask_multiplier.unsqueeze(2) # add time dimension back for video + mask_multiplier = mask_multiplier.repeat(1, 1, noise_pred.shape[2], 1, 1) + loss = loss * mask_multiplier + except Exception as e: + # todo handle mask with video models + print("Could not apply mask multiplier to loss") + print(e) + pass + + prior_loss = None + if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None: + assert not self.train_config.train_turbo + if self.train_config.loss_type == "mae": + prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none") + else: + prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), 
reduction="none") + + prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier + if torch.isnan(prior_loss).any(): + print_acc("Prior loss is nan") + prior_loss = None + else: + if len(noise_pred.shape) == 5: + # video B,C,T,H,W + prior_loss = prior_loss.mean([1, 2, 3, 4]) + else: + prior_loss = prior_loss.mean([1, 2, 3]) + # loss = loss + prior_loss + # loss = loss + prior_loss + # loss = loss + prior_loss + if len(noise_pred.shape) == 5: + loss = loss.mean([1, 2, 3, 4]) + else: + loss = loss.mean([1, 2, 3]) + # apply loss multiplier before prior loss + # multiply by our mask + try: + loss = loss * loss_multiplier + except: + # todo handle mask with video models + pass + if prior_loss is not None: + loss = loss + prior_loss + + if not self.train_config.train_turbo: + if self.train_config.learnable_snr_gos: + # add snr_gamma + loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos) + elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr: + # add snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, + fixed=True) + elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # check for audio loss + if batch.audio_pred is not None and batch.audio_target is not None: + audio_loss = torch.nn.functional.mse_loss(batch.audio_pred.float(), batch.audio_target.float(), reduction="mean") + audio_loss = audio_loss * self.train_config.audio_loss_multiplier + loss = loss + audio_loss + + # check for additional losses + if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None: + + loss = loss + self.adapter.additional_loss.mean() + self.adapter.additional_loss = 
None + + if self.train_config.target_norm_std: + # seperate out the batch and channels + pred_std = noise_pred.std([2, 3], keepdim=True) + norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean() + loss = loss + norm_std_loss + + + loss = loss + additional_loss + + if self.train_config.max_loss_debug and self.train_config.max_loss is not None: + if loss.item() > self.train_config.max_loss: + print_acc(f"Loss {loss.item()} is greater than max loss {self.train_config.max_loss}. Clipping to max loss.") + print_acc(f"timesteps: {timesteps}") + + if self.train_config.max_loss is not None: + loss = torch.clamp(loss, max=self.train_config.max_loss) + + return loss + + def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): + return batch + + def get_guided_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + loss = get_guidance_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + sd=self.sd, + unconditional_embeds=unconditional_embeds, + train_config=self.train_config, + **kwargs + ) + + return loss + + + # ------------------------------------------------------------------ + # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative + # Modelling”, 2025 – see Alg. 1 + Eq. 
(6) of the paper) + # This version avoids jvp / double-back-prop issues with Flash-Attention + # adapted from the work of lodestonerock + # ------------------------------------------------------------------ + def get_mean_flow_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + dtype = get_torch_dtype(self.train_config.dtype) + total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps) # e.g. 1000 + base_eps = 1e-3 + min_time_gap = 1e-2 + + with torch.no_grad(): + num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps + batch_size = batch.latents.shape[0] + timestep_t_list = [] + timestep_r_list = [] + + for i in range(batch_size): + t1 = random.randint(0, num_train_timesteps - 1) + t2 = random.randint(0, num_train_timesteps - 1) + t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)] + t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)] + if (t_t - t_r).item() < min_time_gap * 1000: + scaled_time_gap = min_time_gap * 1000 + if t_t.item() + scaled_time_gap > 1000: + t_r = t_r - scaled_time_gap + else: + t_t = t_t + scaled_time_gap + timestep_t_list.append(t_t) + timestep_r_list.append(t_r) + + timesteps_t = torch.stack(timestep_t_list, dim=0).float() + timesteps_r = torch.stack(timestep_r_list, dim=0).float() + + t_frac = timesteps_t / total_steps # [0,1] + r_frac = timesteps_r / total_steps # [0,1] + + latents_clean = batch.latents.to(dtype) + noise_sample = noise.to(dtype) + + lerp_vector = latents_clean * (1.0 - t_frac[:, None, None, None]) + noise_sample * t_frac[:, None, None, None] + + eps = base_eps + + # concatenate timesteps as input for u(z, r, t) + timesteps_cat = torch.cat([t_frac, r_frac], dim=0) * total_steps + + # model predicts u(z, r, t) + u_pred = 
self.predict_noise( + noisy_latents=lerp_vector.to(dtype), + timesteps=timesteps_cat.to(dtype), + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + + with torch.no_grad(): + t_frac_plus_eps = (t_frac + eps).clamp(0.0, 1.0) + lerp_perturbed = latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None]) + noise_sample * t_frac_plus_eps[:, None, None, None] + timesteps_cat_perturbed = torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps + + u_perturbed = self.predict_noise( + noisy_latents=lerp_perturbed.to(dtype), + timesteps=timesteps_cat_perturbed.to(dtype), + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + + # compute du/dt via finite difference (detached) + du_dt = (u_perturbed - u_pred).detach() / eps + # du_dt = (u_perturbed - u_pred).detach() + du_dt = du_dt.to(dtype) + + + time_gap = (t_frac - r_frac)[:, None, None, None].to(dtype) + time_gap.clamp(min=1e-4) + u_shifted = u_pred + time_gap * du_dt + # u_shifted = u_pred + du_dt / time_gap + # u_shifted = u_pred + + # a step is done like this: + # stepped_latent = model_input + (timestep_next - timestep) * model_output + + # flow target velocity + # v_target = (noise_sample - latents_clean) / time_gap + # flux predicts opposite of velocity, so we need to invert it + v_target = (latents_clean - noise_sample) / time_gap + + # compute loss + loss = torch.nn.functional.mse_loss( + u_shifted.float(), + v_target.float(), + reduction='none' + ) + + with torch.no_grad(): + pure_loss = loss.mean().detach() + pure_loss.requires_grad_(True) + + loss = loss.mean() + if loss.item() > 1e3: + pass + self.accelerator.backward(loss) + return pure_loss + + + + def get_prior_prediction( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 
'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + conditioned_prompts=None, + **kwargs + ): + # todo for embeddings, we need to run without trigger words + was_unet_training = self.sd.unet.training + was_network_active = False + if self.network is not None: + was_network_active = self.network.is_active + self.network.is_active = False + can_disable_adapter = False + was_adapter_active = False + if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or + isinstance(self.adapter, ReferenceAdapter) or + (isinstance(self.adapter, CustomAdapter)) + ): + can_disable_adapter = True + was_adapter_active = self.adapter.is_active + self.adapter.is_active = False + + if self.train_config.unload_text_encoder and self.adapter is not None and not isinstance(self.adapter, CustomAdapter): + raise ValueError("Prior predictions currently do not support unloading text encoder with adapter") + # do a prediction here so we can match its output with network multiplier set to 0.0 + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + + embeds_to_use = conditional_embeds.clone().detach() + # handle clip vision adapter by removing triggers from prompt and replacing with the class name + if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None: + prompt_list = batch.get_caption_list() + class_name = '' + + triggers = ['[trigger]', '[name]'] + remove_tokens = [] + + if self.embed_config is not None: + triggers.append(self.embed_config.trigger) + for i in range(1, self.embed_config.tokens): + remove_tokens.append(f"{self.embed_config.trigger}_{i}") + if self.embed_config.trigger_class_name is not None: + class_name = self.embed_config.trigger_class_name + + if self.adapter is not None: + triggers.append(self.adapter_config.trigger) + for i in range(1, self.adapter_config.num_tokens): + remove_tokens.append(f"{self.adapter_config.trigger}_{i}") + if 
self.adapter_config.trigger_class_name is not None: + class_name = self.adapter_config.trigger_class_name + + for idx, prompt in enumerate(prompt_list): + for remove_token in remove_tokens: + prompt = prompt.replace(remove_token, '') + for trigger in triggers: + prompt = prompt.replace(trigger, class_name) + prompt_list[idx] = prompt + + if batch.prompt_embeds is not None: + embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype) + else: + prompt_kwargs = {} + if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None: + prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype) + embeds_to_use = self.sd.encode_prompt( + prompt_list, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype, + **prompt_kwargs + ).detach() + + # dont use network on this + # self.network.multiplier = 0.0 + self.sd.unet.eval() + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux and not self.sd.is_lumina2: + # we need to remove the image embeds from the prompt except for flux + embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach() + end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens + embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :] + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.clone().detach() + unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos] + + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + + guidance_embedding_scale = self.train_config.cfg_scale + if self.train_config.do_guidance_loss: + guidance_embedding_scale = self._guidance_loss_target_batch + + prior_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=embeds_to_use.to(self.device_torch, 
dtype=dtype).detach(), + unconditional_embeddings=unconditional_embeds, + timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + guidance_embedding_scale=guidance_embedding_scale, + rescale_cfg=self.train_config.cfg_rescale, + batch=batch, + **pred_kwargs # adapter residuals in here + ) + if was_unet_training: + self.sd.unet.train() + prior_pred = prior_pred.detach() + # remove the residuals as we wont use them on prediction when matching control + if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs: + del pred_kwargs['down_intrablock_additional_residuals'] + if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs: + del pred_kwargs['down_block_additional_residuals'] + if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs: + del pred_kwargs['mid_block_additional_residual'] + + if can_disable_adapter: + self.adapter.is_active = was_adapter_active + # restore network + # self.network.multiplier = network_weight_list + if self.network is not None: + self.network.is_active = was_network_active + return prior_pred + + def before_unet_predict(self): + pass + + def after_unet_predict(self): + pass + + def end_of_training_loop(self): + pass + + def predict_noise( + self, + noisy_latents: torch.Tensor, + timesteps: Union[int, torch.Tensor] = 1, + conditional_embeds: Union[PromptEmbeds, None] = None, + unconditional_embeds: Union[PromptEmbeds, None] = None, + batch: Optional['DataLoaderBatchDTO'] = None, + is_primary_pred: bool = False, + **kwargs, + ): + dtype = get_torch_dtype(self.train_config.dtype) + guidance_embedding_scale = self.train_config.cfg_scale + if self.train_config.do_guidance_loss: + guidance_embedding_scale = self._guidance_loss_target_batch + return self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeddings=unconditional_embeds, + 
timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + guidance_embedding_scale=guidance_embedding_scale, + detach_unconditional=False, + rescale_cfg=self.train_config.cfg_rescale, + bypass_guidance_embedding=self.train_config.bypass_guidance_embedding, + batch=batch, + **kwargs + ) + + + def train_single_accumulation(self, batch: DataLoaderBatchDTO): + with torch.no_grad(): + self.timer.start('preprocess_batch') + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_raw(batch) + batch = self.preprocess_batch(batch) + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_processed(batch) + dtype = get_torch_dtype(self.train_config.dtype) + # sanity check + if self.sd.vae.dtype != self.sd.vae_torch_dtype: + self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + if encoder.dtype != self.sd.te_torch_dtype: + encoder.to(self.sd.te_torch_dtype) + else: + if self.sd.text_encoder.dtype != self.sd.te_torch_dtype: + self.sd.text_encoder.to(self.sd.te_torch_dtype) + + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + if self.train_config.do_cfg or self.train_config.do_random_cfg: + # pick random negative prompts + if self.negative_prompt_pool is not None: + negative_prompts = [] + for i in range(noisy_latents.shape[0]): + num_neg = random.randint(1, self.train_config.max_negative_prompts) + this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)] + this_neg_prompt = ', '.join(this_neg_prompts) + negative_prompts.append(this_neg_prompt) + self.batch_negative_prompt = negative_prompts + else: + self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] + + if self.adapter and isinstance(self.adapter, CustomAdapter): + # condition the prompt + # todo handle more than one adapter image + conditioned_prompts = 
self.adapter.condition_prompt(conditioned_prompts) + + network_weight_list = batch.get_network_weight_list() + if self.train_config.single_item_batching: + network_weight_list = network_weight_list + network_weight_list + + has_adapter_img = batch.control_tensor is not None + has_clip_image = batch.clip_image_tensor is not None + has_clip_image_embeds = batch.clip_image_embeds is not None + # force it to be true if doing regs as we handle those differently + if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]): + has_clip_image = True + if self._clip_image_embeds_unconditional is not None: + has_clip_image_embeds = True # we are caching embeds, handle that differently + has_clip_image = False + + # do prior pred if prior regularization batch + do_reg_prior = False + if any([batch.file_items[idx].prior_reg for idx in range(len(batch.file_items))]): + do_reg_prior = True + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: + raise ValueError( + "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. 
Please update your dataset config ") + + match_adapter_assist = False + + # check if we are matching the adapter assistant + if self.assistant_adapter: + if self.train_config.match_adapter_chance == 1.0: + match_adapter_assist = True + elif self.train_config.match_adapter_chance > 0.0: + match_adapter_assist = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) < self.train_config.match_adapter_chance + + self.timer.stop('preprocess_batch') + + is_reg = False + loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + for idx, file_item in enumerate(batch.file_items): + if file_item.is_reg: + loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight + is_reg = True + + adapter_images = None + sigmas = None + if has_adapter_img and (self.adapter or self.assistant_adapter): + with self.timer('get_adapter_images'): + # todo move this to data loader + if batch.control_tensor is not None: + adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() + # match in channels + if self.assistant_adapter is not None: + in_channels = self.assistant_adapter.config.in_channels + if adapter_images.shape[1] != in_channels: + # we need to match the channels + adapter_images = adapter_images[:, :in_channels, :, :] + else: + raise NotImplementedError("Adapter images now must be loaded with dataloader") + + clip_images = None + if has_clip_image: + with self.timer('get_clip_images'): + # todo move this to data loader + if batch.clip_image_tensor is not None: + clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach() + + mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + if batch.mask_tensor is not None and self.sd.do_masked_loss: + with self.timer('get_mask_multiplier'): + # upsampling no supported for bfloat16 + mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach() + # scale 
down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height) + if len(noisy_latents.shape) == 5: + # video B,C,T,H,W + h = noisy_latents.shape[3] + w = noisy_latents.shape[4] + else: + h = noisy_latents.shape[2] + w = noisy_latents.shape[3] + mask_multiplier = torch.nn.functional.interpolate( + mask_multiplier, size=(h, w) + ) + # expand to match latents + mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) + mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + # make avg 1.0 + mask_multiplier = mask_multiplier / mask_multiplier.mean() + + def get_adapter_multiplier(): + if self.adapter and isinstance(self.adapter, T2IAdapter): + # training a t2i adapter, not using as assistant. + return 1.0 + elif match_adapter_assist: + # training a texture. We want it high + adapter_strength_min = 0.9 + adapter_strength_max = 1.0 + else: + # training with assistance, we want it low + # adapter_strength_min = 0.4 + # adapter_strength_max = 0.7 + adapter_strength_min = 0.5 + adapter_strength_max = 1.1 + + adapter_conditioning_scale = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) + + adapter_conditioning_scale = value_map( + adapter_conditioning_scale, + 0.0, + 1.0, + adapter_strength_min, + adapter_strength_max + ) + return adapter_conditioning_scale + + # flush() + with self.timer('grad_setup'): + + # text encoding + grad_on_text_encoder = False + if self.train_config.train_text_encoder: + grad_on_text_encoder = True + + if self.embedding is not None: + grad_on_text_encoder = True + + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + grad_on_text_encoder = True + + if self.adapter_config and self.adapter_config.type == 'te_augmenter': + grad_on_text_encoder = True + + # have a blank network so we can wrap it in a context and set multipliers without checking every time + if self.network is not None: + network = self.network + else: 
+ network = BlankNetwork() + + # set the weights + network.multiplier = network_weight_list + + # activate network if it exits + + prompts_1 = conditioned_prompts + prompts_2 = None + if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl: + prompts_1 = batch.get_caption_short_list() + prompts_2 = conditioned_prompts + + # make the batch splits + if self.train_config.single_item_batching: + if self.model_config.refiner_name_or_path is not None: + raise ValueError("Single item batching is not supported when training the refiner") + batch_size = noisy_latents.shape[0] + # chunk/split everything + noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0) + noise_list = torch.chunk(noise, batch_size, dim=0) + timesteps_list = torch.chunk(timesteps, batch_size, dim=0) + conditioned_prompts_list = [[prompt] for prompt in prompts_1] + if imgs is not None: + imgs_list = torch.chunk(imgs, batch_size, dim=0) + else: + imgs_list = [None for _ in range(batch_size)] + if adapter_images is not None: + adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0) + else: + adapter_images_list = [None for _ in range(batch_size)] + if clip_images is not None: + clip_images_list = torch.chunk(clip_images, batch_size, dim=0) + else: + clip_images_list = [None for _ in range(batch_size)] + mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0) + if prompts_2 is None: + prompt_2_list = [None for _ in range(batch_size)] + else: + prompt_2_list = [[prompt] for prompt in prompts_2] + + else: + noisy_latents_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditioned_prompts_list = [prompts_1] + imgs_list = [imgs] + adapter_images_list = [adapter_images] + clip_images_list = [clip_images] + mask_multiplier_list = [mask_multiplier] + if prompts_2 is None: + prompt_2_list = [None] + else: + prompt_2_list = [prompts_2] + + for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, 
clip_images, mask_multiplier, prompt_2 in zip( + noisy_latents_list, + noise_list, + timesteps_list, + conditioned_prompts_list, + imgs_list, + adapter_images_list, + clip_images_list, + mask_multiplier_list, + prompt_2_list + ): + + # if self.train_config.negative_prompt is not None: + # # add negative prompt + # conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in + # range(len(conditioned_prompts))] + # if prompt_2 is not None: + # prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] + + with (network): + # encode clip adapter here so embeds are active for tokenizer + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + with self.timer('encode_clip_vision_embeds'): + if has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True + ) + else: + # just do a blank one + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, 512, 512), + device=self.device_torch, dtype=dtype + ), + is_training=True, + has_been_preprocessed=True, + drop=True + ) + # it will be injected into the tokenizer when called + self.adapter(conditional_clip_embeds) + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or is_reg): + quad_count = random.randint(1, 4) + self.adapter.train() + self.adapter.trigger_pre_te( + tensors_preprocessed=clip_images if not is_reg else None, # on regs we send none to get random noise + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + batch_tensor=batch.tensor if not is_reg else None, + batch_size=noisy_latents.shape[0] + ) + + with self.timer('encode_prompt'): + unconditional_embeds = None + prompt_kwargs = {} + if self.sd.encode_control_in_text_embeddings and 
batch.control_tensor is not None: + prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: + with torch.set_grad_enabled(False): + if batch.prompt_embeds is not None: + # use the cached embeds + conditional_embeds = batch.prompt_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + else: + embeds_to_use = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + if self.cached_trigger_embeds is not None and not is_reg: + embeds_to_use = self.cached_trigger_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + conditional_embeds = concat_prompt_embeds( + [embeds_to_use] * noisy_latents.shape[0] + ) + if self.train_config.do_cfg: + unconditional_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + unconditional_embeds = concat_prompt_embeds( + [unconditional_embeds] * noisy_latents.shape[0] + ) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + elif grad_on_text_encoder: + with torch.set_grad_enabled(True): + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + # todo only do one and repeat it + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + 
self.adapter.is_unconditional_run = False + else: + with torch.set_grad_enabled(False): + # make sure it is in eval mode + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + else: + self.sd.text_encoder.eval() + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + if self.sd.encode_control_in_text_embeddings and batch.control_tensor_list is not None: + prompt_kwargs['control_images'] = batch.control_tensor_list + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + if self.train_config.diff_output_preservation: + dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts] + dop_prompts_2 = None + if prompt_2 is not None: + dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2] + self.diff_output_preservation_embeds = self.sd.encode_prompt( + dop_prompts, dop_prompts_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + if self.train_config.do_cfg: + unconditional_embeds = unconditional_embeds.detach() + + if self.decorator: + conditional_embeds.text_embeds = self.decorator( + 
conditional_embeds.text_embeds + ) + if self.train_config.do_cfg: + unconditional_embeds.text_embeds = self.decorator( + unconditional_embeds.text_embeds, + is_unconditional=True + ) + + # flush() + pred_kwargs = {} + + if has_adapter_img: + if (self.adapter and isinstance(self.adapter, T2IAdapter)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): + with torch.set_grad_enabled(self.adapter is not None): + adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + down_block_additional_residuals = adapter(adapter_images) + if self.assistant_adapter: + # not training. detach + down_block_additional_residuals = [ + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in + down_block_additional_residuals + ] + else: + down_block_additional_residuals = [ + sample.to(dtype=dtype) * adapter_multiplier for sample in + down_block_additional_residuals + ] + + pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals + + if self.adapter and isinstance(self.adapter, IPAdapter): + with self.timer('encode_adapter_embeds'): + # number of images to do if doing a quad image + quad_count = random.randint(1, 4) + image_size = self.adapter.input_size + if has_clip_image_embeds: + # todo handle reg images better than this + if is_reg: + # get unconditional image embeds from cache + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + if self.train_config.do_cfg: + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + else: + 
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds_unconditional, + quad_count=quad_count + ) + elif is_reg: + # we will zero it out in the img embedder + clip_images = torch.zeros( + (noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach() + # drop will zero it out + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images, + drop=True, + is_training=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach(), + is_training=True, + drop=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + elif has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + # do cfg on clip embeds to normalize the embeddings for when doing cfg + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + drop=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + else: + print_acc("No Clip Image") + print_acc([file_item.path for file_item in batch.file_items]) + raise ValueError("Could not find clip image") + + if not self.adapter_config.train_image_encoder: + # we are not training the image encoder, so we 
need to detach the embeds + conditional_clip_embeds = conditional_clip_embeds.detach() + if self.train_config.do_cfg: + unconditional_clip_embeds = unconditional_clip_embeds.detach() + + with self.timer('encode_adapter'): + self.adapter.train() + conditional_embeds = self.adapter( + conditional_embeds.detach(), + conditional_clip_embeds, + is_unconditional=False + ) + if self.train_config.do_cfg: + unconditional_embeds = self.adapter( + unconditional_embeds.detach(), + unconditional_clip_embeds, + is_unconditional=True + ) + else: + # wipe out unconsitional + self.adapter.last_unconditional = None + + if self.adapter and isinstance(self.adapter, ReferenceAdapter): + # pass in our scheduler + self.adapter.noise_scheduler = self.lr_scheduler + if has_clip_image or has_adapter_img: + img_to_use = clip_images if has_clip_image else adapter_images + # currently 0-1 needs to be -1 to 1 + reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype) + self.adapter.set_reference_images(reference_images) + self.adapter.noise_scheduler = self.sd.noise_scheduler + elif is_reg: + self.adapter.set_blank_reference_images(noisy_latents.shape[0]) + else: + self.adapter.set_reference_images(None) + + prior_pred = None + + do_inverted_masked_prior = False + if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: + do_inverted_masked_prior = True + + do_correct_pred_norm_prior = self.train_config.correct_pred_norm + + do_guidance_prior = False + + if batch.unconditional_latents is not None: + # for this not that, we need a prior pred to normalize + guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type + if guidance_type == 'tnt': + do_guidance_prior = True + + if (( + has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): + with self.timer('prior predict'): + 
prior_embeds_to_use = conditional_embeds + # use diff_output_preservation embeds if doing dfe + if self.train_config.diff_output_preservation: + prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) + + if self.train_config.blank_prompt_preservation: + blank_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + prior_embeds_to_use = concat_prompt_embeds( + [blank_embeds] * noisy_latents.shape[0] + ) + + prior_pred = self.get_prior_prediction( + noisy_latents=noisy_latents, + conditional_embeds=prior_embeds_to_use, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + noise=noise, + batch=batch, + unconditional_embeds=unconditional_embeds, + conditioned_prompts=conditioned_prompts + ) + if prior_pred is not None: + prior_pred = prior_pred.detach() + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or self.adapter_config.type in ['llm_adapter', 'text_encoder']): + quad_count = random.randint(1, 4) + self.adapter.train() + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=conditional_embeds, + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + if self.train_config.do_cfg and unconditional_embeds is not None: + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=unconditional_embeds, + is_training=True, + has_been_preprocessed=True, + is_unconditional=True, + quad_count=quad_count + ) + + if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None: + self.adapter.add_extra_values(batch.extra_values.detach()) + + if self.train_config.do_cfg: + self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), + is_unconditional=True) + + if 
has_adapter_img: + if (self.adapter and isinstance(self.adapter, ControlNetModel)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): + if self.train_config.do_cfg: + raise ValueError("ControlNetModel is not supported with CFG") + with torch.set_grad_enabled(self.adapter is not None): + adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + # add_text_embeds is pooled_prompt_embeds for sdxl + added_cond_kwargs = {} + if self.sd.is_xl: + added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds + added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents) + down_block_res_samples, mid_block_res_sample = adapter( + noisy_latents, + timesteps, + encoder_hidden_states=conditional_embeds.text_embeds, + controlnet_cond=adapter_images, + conditioning_scale=1.0, + guess_mode=False, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + pred_kwargs['down_block_additional_residuals'] = down_block_res_samples + pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample + + if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list): + batch_size = noisy_latents.shape[0] + # update the guidance value, random float between guidance_loss_target[0] and guidance_loss_target[1] + self._guidance_loss_target_batch = [ + random.uniform( + self.train_config.guidance_loss_target[0], + self.train_config.guidance_loss_target[1] + ) for _ in range(batch_size) + ] + + self.before_unet_predict() + + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + with self.timer('condition_noisy_latents'): + # do it for the model + noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch) + if self.adapter and isinstance(self.adapter, CustomAdapter): + 
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) + + if self.train_config.timestep_type == 'next_sample': + with self.timer('next_sample_step'): + with torch.no_grad(): + + stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps] + stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies] + stepped_timesteps = torch.stack(stepped_timesteps, dim=0) + + # do a sample at the current timestep and step it, then determine new noise + next_sample_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + stepped_latents = self.sd.step_scheduler( + next_sample_pred, + noisy_latents, + timesteps, + self.sd.noise_scheduler + ) + # stepped latents is our new noisy latents. Now we need to determine noise in the current sample + noisy_latents = stepped_latents + original_samples = batch.latents.to(self.device_torch, dtype=dtype) + # todo calc next timestep, for now this may work as it + t_01 = (stepped_timesteps / 1000).to(original_samples.device) + if len(stepped_latents.shape) == 4: + t_01 = t_01.view(-1, 1, 1, 1) + elif len(stepped_latents.shape) == 5: + t_01 = t_01.view(-1, 1, 1, 1, 1) + else: + raise ValueError("Unknown stepped latents shape", stepped_latents.shape) + next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01 + noise = next_sample_noise + timesteps = stepped_timesteps + # do a prior pred if we have an unconditional image, we will swap out the giadance later + if batch.unconditional_latents is not None or self.do_guided_loss: + # do guided loss + loss = self.get_guided_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + 
network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + unconditional_embeds=unconditional_embeds, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + ) + + elif self.train_config.loss_type == 'mean_flow': + loss = self.get_mean_flow_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + unconditional_embeds=unconditional_embeds, + prior_pred=prior_pred, + ) + else: + with self.timer('predict_unet'): + noise_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + is_primary_pred=True, + **pred_kwargs + ) + self.after_unet_predict() + + with self.timer('calculate_loss'): + noise = noise.to(self.device_torch, dtype=dtype).detach() + prior_to_calculate_loss = prior_pred + # if we are doing diff_output_preservation and not noing inverted masked prior + # then we need to send none here so it will not target the prior + doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation + if doing_preservation and not do_inverted_masked_prior: + prior_to_calculate_loss = None + + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_to_calculate_loss, + ) + + if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation: + with torch.no_grad(): + if self.train_config.diff_output_preservation: + preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) + elif 
self.train_config.blank_prompt_preservation: + blank_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + preservation_embeds = concat_prompt_embeds( + [blank_embeds] * noisy_latents.shape[0] + ) + preservation_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier + preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier + self.additional_logs['loss/normal'] = loss.item() + self.additional_logs['loss/preservation'] = preservation_loss.item() + loss = loss + preservation_loss + + # check if nan + if torch.isnan(loss): + print_acc("loss is nan") + loss = torch.zeros_like(loss).requires_grad_(True) + + with self.timer('backward'): + # todo we have multiplier seperated. works for now as res are not in same batch, but need to change + loss = loss * loss_multiplier.mean() + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. 
def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
    """Run one training step over a single batch or a list of batches.

    Each batch is passed to ``self.train_single_accumulation`` (which performs
    the backward pass); the per-batch losses are summed and averaged for
    reporting. Unless ``self.is_grad_accumulation_step`` is set, gradients are
    clipped (skipped for the 'adafactor' optimizer), the optimizer is stepped
    and gradients are cleared. The LR scheduler is stepped on every call.

    Args:
        batch: one batch, or a list of batches processed back to back before
            the optimizer step.

    Returns:
        OrderedDict with a single ``'loss'`` key holding the mean of the
        per-batch losses as a Python number (via ``.item()``).

    Raises:
        ValueError: if ``batch`` is an empty list, or if the model is
            multistage and no boundary index reachable from the current one
            is listed in ``self.sd.trainable_multistage_boundaries``.
    """
    batch_list = batch if isinstance(batch, list) else [batch]
    if len(batch_list) == 0:
        # Without this guard the optimizer would step on stale/zero grads and
        # the loss average below would fail on ``None / 0``.
        raise ValueError("hook_train_loop requires at least one batch")

    total_loss = None
    self.optimizer.zero_grad()
    for single_batch in batch_list:
        if self.sd.is_multistage:
            # handle multistage switching
            needs_switch = (
                self.steps_this_boundary >= self.train_config.switch_boundary_every
                or self.current_boundary_index not in self.sd.trainable_multistage_boundaries
            )
            if needs_switch:
                # Walk forward (wrapping around) until we land on a trainable
                # boundary. The walk is deterministic, so revisiting an index
                # means we are cycling and would otherwise loop forever.
                visited = set()
                while True:
                    self.steps_this_boundary = 0
                    self.current_boundary_index += 1
                    if self.current_boundary_index >= len(self.sd.multistage_boundaries):
                        self.current_boundary_index = 0
                    if self.current_boundary_index in self.sd.trainable_multistage_boundaries:
                        # if this boundary is trainable, we can stop looking
                        break
                    if self.current_boundary_index in visited:
                        raise ValueError(
                            "No trainable multistage boundary found: "
                            f"trainable_multistage_boundaries={self.sd.trainable_multistage_boundaries!r}, "
                            f"number of boundaries={len(self.sd.multistage_boundaries)}"
                        )
                    visited.add(self.current_boundary_index)

        loss = self.train_single_accumulation(single_batch)
        self.steps_this_boundary += 1
        if total_loss is None:
            total_loss = loss
        else:
            total_loss += loss
        if len(batch_list) > 1 and self.model_config.low_vram:
            torch.cuda.empty_cache()

    if not self.is_grad_accumulation_step:
        # fix this for multi params
        if self.train_config.optimizer != 'adafactor':
            if isinstance(self.params[0], dict):
                # param groups: clip each group's parameters separately
                for param_group in self.params:
                    self.accelerator.clip_grad_norm_(param_group['params'], self.train_config.max_grad_norm)
            else:
                self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
        # only step if we are not accumulating
        with self.timer('optimizer_step'):
            self.optimizer.step()

            self.optimizer.zero_grad(set_to_none=True)
        if self.adapter and isinstance(self.adapter, CustomAdapter):
            self.adapter.post_weight_update()
        if self.ema is not None:
            with self.timer('ema_update'):
                self.ema.update()
    else:
        # gradient accumulation. Just a place for breakpoint
        pass

    # TODO Should we only step scheduler on grad step? If so, need to recalculate last step
    with self.timer('scheduler_step'):
        self.lr_scheduler.step()

    if self.embedding is not None:
        with self.timer('restore_embeddings'):
            # Let's make sure we don't update any embedding weights besides the newly added token
            self.embedding.restore_embeddings()
    if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
        with self.timer('restore_adapter'):
            # Let's make sure we don't update any embedding weights besides the newly added token
            self.adapter.restore_embeddings()

    loss_dict = OrderedDict(
        {'loss': (total_loss / len(batch_list)).item()}
    )

    self.end_of_training_loop()

    return loss_dict
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
    """Set up trainer state on top of the base train process.

    Initialises the flags and caches used by the training step and validates
    the preservation / guidance-loss options read from ``self.train_config``.

    Raises:
        ValueError: if ``diff_output_preservation`` is enabled without a
            trigger word or network config, or together with
            ``train_text_encoder``; if ``blank_prompt_preservation`` is
            enabled without a network config; or if ``guidance_loss_target``
            is neither a number nor a non-empty list.
    """
    super().__init__(process_id, job, config, **kwargs)
    # Explicitly start at None so the attribute exists before
    # before_dataset_load() runs (a bare annotation does not create it).
    self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None] = None
    self.do_prior_prediction = False
    self.do_long_prompts = False
    self.do_guided_loss = False
    self.taesd: Optional[AutoencoderTiny] = None

    self._clip_image_embeds_unconditional: Union[List[str], None] = None
    self.negative_prompt_pool: Union[List[str], None] = None
    self.batch_negative_prompt: Union[List[str], None] = None

    self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"

    # grad scaling is turned off for bf16 fine tuning and for trained adapters
    self.do_grad_scale = True
    if self.is_fine_tuning and self.is_bfloat:
        self.do_grad_scale = False
    if self.adapter_config is not None:
        if self.adapter_config.train:
            self.do_grad_scale = False

    # if self.train_config.dtype in ["fp16", "float16"]:
    #     # patch the scaler to allow fp16 training
    #     org_unscale_grads = self.scaler._unscale_grads_
    #     def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
    #         return org_unscale_grads(optimizer, inv_scale, found_inf, True)
    #     self.scaler._unscale_grads_ = _unscale_grads_replacer

    self.cached_blank_embeds: Optional[PromptEmbeds] = None
    self.cached_trigger_embeds: Optional[PromptEmbeds] = None
    self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None

    self.dfe: Optional[DiffusionFeatureExtractor] = None
    self.unconditional_embeds = None

    if self.train_config.diff_output_preservation:
        if self.trigger_word is None:
            raise ValueError("diff_output_preservation requires a trigger_word to be set")
        if self.network_config is None:
            raise ValueError("diff_output_preservation requires a network to be set")
        if self.train_config.train_text_encoder:
            raise ValueError("diff_output_preservation is not supported with train_text_encoder")

    if self.train_config.blank_prompt_preservation:
        if self.network_config is None:
            raise ValueError("blank_prompt_preservation requires a network to be set")

    if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
        # always do a prior prediction when doing output preservation
        self.do_prior_prediction = True

    # store the loss target for a batch so we can use it in a loss.
    # A float here; the training step may replace it with a per-item list
    # when guidance_loss_target is a [min, max] list.
    guidance_target = self.train_config.guidance_loss_target
    self._guidance_loss_target_batch: Union[float, List[float]] = 0.0
    if isinstance(guidance_target, (int, float)):
        self._guidance_loss_target_batch = float(guidance_target)
    elif isinstance(guidance_target, list):
        if len(guidance_target) == 0:
            raise ValueError("guidance_loss_target list must not be empty")
        self._guidance_loss_target_batch = float(guidance_target[0])
    else:
        raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}")


def before_model_load(self):
    """Hook called before the model is loaded; intentionally a no-op here."""
    pass


def cache_sample_prompts(self):
    """Pre-encode every sample prompt and store the embeddings on ``self.sd``.

    Does nothing when sampling is disabled or no samples are configured.
    Otherwise ``self.sd.sample_prompts_cache`` is replaced with one dict per
    prompt, ``{'conditional': ..., 'unconditional': ...}``, each value being
    the result of ``self.sd.encode_prompt(...)`` moved to ``'cpu'``. When the
    model encodes control images into the text embeddings and the sample has
    control images, they are loaded and passed as ``control_images``.
    """
    if self.train_config.disable_sampling:
        return
    sample_config = self.sample_config
    if sample_config is None or sample_config.samples is None or len(sample_config.samples) == 0:
        return

    # cache all the samples
    self.sd.sample_prompts_cache = []
    sample_folder = os.path.join(self.save_root, 'samples')
    output_path = os.path.join(sample_folder, 'test.jpg')
    for i, prompt in enumerate(sample_config.prompts):
        sample_item = sample_config.samples[i]

        # needed so we can autoparse the prompt to handle flags
        gen_img_config = GenerateImageConfig(
            prompt=prompt,  # it will autoparse the prompt
            negative_prompt=sample_item.neg,
            output_path=output_path,
            ctrl_img=sample_item.ctrl_img,
            ctrl_img_1=sample_item.ctrl_img_1,
            ctrl_img_2=sample_item.ctrl_img_2,
            ctrl_img_3=sample_item.ctrl_img_3,
        )

        # control image paths in their fixed slot order, skipping unset slots
        control_paths = [
            path for path in (
                gen_img_config.ctrl_img,
                gen_img_config.ctrl_img_1,
                gen_img_config.ctrl_img_2,
                gen_img_config.ctrl_img_3,
            ) if path is not None
        ]

        encode_kwargs = {}
        # see if we need to encode the control images
        if self.sd.encode_control_in_text_embeddings and len(control_paths) > 0:
            ctrl_img_list = []
            for path in control_paths:
                # close the file handle as soon as the pixels are converted
                with Image.open(path) as opened:
                    rgb_image = opened.convert("RGB")
                # convert to 0 to 1 tensor
                ctrl_img_list.append(
                    TF.to_tensor(rgb_image)
                    .unsqueeze(0)
                    .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                )

            if self.sd.has_multiple_control_images:
                encode_kwargs['control_images'] = ctrl_img_list
            else:
                encode_kwargs['control_images'] = ctrl_img_list[0]

        positive = self.sd.encode_prompt(gen_img_config.prompt, **encode_kwargs).to('cpu')
        negative = self.sd.encode_prompt(gen_img_config.negative_prompt, **encode_kwargs).to('cpu')

        self.sd.sample_prompts_cache.append({
            'conditional': positive,
            'unconditional': negative
        })


def before_dataset_load(self):
    """Load the optional assistant adapter and the optional TAESD decoder.

    ``self.assistant_adapter`` is reset to ``None`` and, when
    ``adapter_assist_name_or_path`` is configured, replaced by a frozen
    (eval mode, no grad) T2IAdapter or ControlNetModel. When turbo training
    with ``show_turbo_outputs`` is enabled, a frozen AutoencoderTiny is loaded
    into ``self.taesd``.

    Raises:
        ValueError: if ``adapter_assist_type`` is neither ``"t2i"`` nor
            ``"control_net"``.
    """
    self.assistant_adapter = None
    # get adapter assistant if one is set
    if self.train_config.adapter_assist_name_or_path is not None:
        adapter_path = self.train_config.adapter_assist_name_or_path

        if self.train_config.adapter_assist_type == "t2i":
            # dont name this adapter since we are not training it
            self.assistant_adapter = T2IAdapter.from_pretrained(
                adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
            ).to(self.device_torch)
        elif self.train_config.adapter_assist_type == "control_net":
            self.assistant_adapter = ControlNetModel.from_pretrained(
                adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
            ).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
        else:
            raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")

        # the assistant is never trained
        self.assistant_adapter.eval()
        self.assistant_adapter.requires_grad_(False)
        flush()
    if self.train_config.train_turbo and self.train_config.show_turbo_outputs:
        if self.model_config.is_xl:
            self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl",
                                                         torch_dtype=get_torch_dtype(self.train_config.dtype))
        else:
            self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd",
                                                         torch_dtype=get_torch_dtype(self.train_config.dtype))
        self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch)
        self.taesd.eval()
        self.taesd.requires_grad_(False)
Already cached + self.sd.vae.to('cpu') + flush() + add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch) + if self.adapter is not None: + self.adapter.to(self.device_torch) + + # check if we have regs and using adapter and caching clip embeddings + has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0 + is_caching_clip_embeddings = self.datasets is not None and any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))]) + + if has_reg and is_caching_clip_embeddings: + # we need a list of unconditional clip image embeds from other datasets to handle regs + unconditional_clip_image_embeds = [] + datasets = get_dataloader_datasets(self.data_loader) + for i in range(len(datasets)): + unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache + + if len(unconditional_clip_image_embeds) == 0: + raise ValueError("No unconditional clip image embeds found. This should not happen") + + self._clip_image_embeds_unconditional = unconditional_clip_image_embeds + + if self.train_config.negative_prompt is not None: + if os.path.exists(self.train_config.negative_prompt): + with open(self.train_config.negative_prompt, 'r') as f: + self.negative_prompt_pool = f.readlines() + # remove empty + self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""] + else: + # single prompt + self.negative_prompt_pool = [self.train_config.negative_prompt] + + # handle unload text encoder + if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: + print_acc("Caching embeddings and unloading text encoder") + with torch.no_grad(): + if self.train_config.train_text_encoder: + raise ValueError("Cannot unload text encoder if training text encoder") + # cache embeddings + self.sd.text_encoder_to(self.device_torch) + encode_kwargs = {} + if self.sd.encode_control_in_text_embeddings: + # just do a blank image for unconditionals + control_image = torch.zeros((1, 3, 224, 
224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] + encode_kwargs['control_images'] = control_image + self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs) + if self.trigger_word is not None: + self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs) + if self.train_config.diff_output_preservation: + self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) + + self.cache_sample_prompts() + + print_acc("\n***** UNLOADING TEXT ENCODER *****") + if self.is_caching_text_embeddings: + print_acc("Embeddings cached to disk. We dont need the text encoder anymore") + else: + print_acc("This will train only with a blank prompt or trigger word, if set") + print_acc("If this is not what you want, remove the unload_text_encoder flag") + print_acc("***********************************") + print_acc("") + + # unload the text encoder + if self.is_caching_text_embeddings: + unload_text_encoder(self.sd) + else: + # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing. + # keep legacy usage for now. 
+ self.sd.text_encoder_to("cpu") + flush() + + if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None: + # make sure we have this if not unloading + self.cached_blank_embeds = self.sd.encode_prompt("").to( + self.device_torch, + dtype=self.sd.torch_dtype + ).detach() + + if self.train_config.diffusion_feature_extractor_path is not None: + vae = self.sd.vae + # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": + # vae = self.sd.vae + self.dfe = load_dfe( + self.train_config.diffusion_feature_extractor_path, + vae=vae, + sd=self.sd + ) + self.dfe.to(self.device_torch) + if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing: + # must be set to train for gradient checkpointing to work + self.dfe.vision_encoder.train() + self.dfe.vision_encoder.gradient_checkpointing = True + elif hasattr(self.dfe, 'model') and self.train_config.gradient_checkpointing: + if hasattr(self.dfe.model, 'enable_gradient_checkpointing'): + self.dfe.model.train() + self.dfe.model.enable_gradient_checkpointing() + if hasattr(self.dfe.model, 'gradient_checkpointing_enable'): + self.dfe.model.train() + self.dfe.model.gradient_checkpointing_enable() + elif hasattr(self.dfe.model, 'gradient_checkpointing'): + self.dfe.model.train() + self.dfe.model.gradient_checkpointing = True + else: + print_acc("Warning: Could not enable gradient checkpointing on diffusion feature extractor model.") + else: + self.dfe.eval() + + # enable gradient checkpointing on the vae + if vae is not None and self.train_config.gradient_checkpointing: + try: + vae.enable_gradient_checkpointing() + vae.train() + except: + pass + + + def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch): + # to process turbo learning, we make one big step from our current timestep to the end + # we then denoise the prediction on that remaining step and target our loss to our target latents + # this 
currently only works on euler_a (that I know of). Would work on others, but needs to be coded to do so. + # needs to be done on each item in batch as they may all have different timesteps + batch_size = pred.shape[0] + pred_chunks = torch.chunk(pred, batch_size, dim=0) + noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0) + timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0) + latent_chunks = torch.chunk(batch.latents, batch_size, dim=0) + noise_chunks = torch.chunk(noise, batch_size, dim=0) + + with torch.no_grad(): + # set the timesteps to 1000 so we can capture them to calculate the sigmas + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.config.num_train_timesteps, + device=self.device_torch + ) + train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach() + + train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach() + + # set the scheduler to one timestep, we build the step and sigmas for each item in batch for the partial step + self.sd.noise_scheduler.set_timesteps( + 1, + device=self.device_torch + ) + + denoised_pred_chunks = [] + target_pred_chunks = [] + + for i in range(batch_size): + pred_item = pred_chunks[i] + noisy_latents_item = noisy_latents_chunks[i] + timesteps_item = timesteps_chunks[i] + latents_item = latent_chunks[i] + noise_item = noise_chunks[i] + with torch.no_grad(): + timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0] + single_step_timestep_schedule = [timesteps_item.squeeze().item()] + # extract the sigma idx for our midpoint timestep + sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch) + + end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1) + end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch) + + # add noise to our target + + # build the big sigma step. 
The to step will now be to 0 giving it a full remaining denoising half step + # self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach() + self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach() + # set our single timstep + self.sd.noise_scheduler.timesteps = torch.from_numpy( + np.array(single_step_timestep_schedule, dtype=np.float32) + ).to(device=self.device_torch) + + # set the step index to None so it will be recalculated on first step + self.sd.noise_scheduler._step_index = None + + denoised_latent = self.sd.noise_scheduler.step( + pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False + )[0] + + residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype( + self.train_config.dtype)) + # remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically) + denoised_latent = denoised_latent - residual_noise + + denoised_pred_chunks.append(denoised_latent) + + denoised_latents = torch.cat(denoised_pred_chunks, dim=0) + # set the scheduler back to the original timesteps + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.config.num_train_timesteps, + device=self.device_torch + ) + + output = denoised_latents / self.sd.vae.config['scaling_factor'] + output = self.sd.vae.decode(output).sample + + if self.train_config.show_turbo_outputs: + # since we are completely denoising, we can show them here + with torch.no_grad(): + show_tensors(output) + + # we return our big partial step denoised latents as our pred and our untouched latents as our target. + # you can do mse against the two here or run the denoised through the vae for pixel space loss against the + # input tensor images. 
+ + return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + + # you can expand these in a child class to make customization easier + def calculate_loss( + self, + noise_pred: torch.Tensor, + noise: torch.Tensor, + noisy_latents: torch.Tensor, + timesteps: torch.Tensor, + batch: 'DataLoaderBatchDTO', + mask_multiplier: Union[torch.Tensor, float] = 1.0, + prior_pred: Union[torch.Tensor, None] = None, + **kwargs + ): + loss_target = self.train_config.loss_target + is_reg = any(batch.get_is_reg_list()) + additional_loss = 0.0 + + prior_mask_multiplier = None + target_mask_multiplier = None + dtype = get_torch_dtype(self.train_config.dtype) + + has_mask = batch.mask_tensor is not None + + with torch.no_grad(): + loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32) + + if self.train_config.match_noise_norm: + # match the norm of the noise + noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True) + noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True) + noise_pred = noise_pred * (noise_norm / noise_pred_norm) + + if self.train_config.pred_scaler != 1.0: + noise_pred = noise_pred * self.train_config.pred_scaler + + target = None + + if self.train_config.target_noise_multiplier != 1.0: + noise = noise * self.train_config.target_noise_multiplier + + if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask): + if self.train_config.correct_pred_norm and not is_reg: + with torch.no_grad(): + # this only works if doing a prior pred + if prior_pred is not None: + prior_mean = prior_pred.mean([2,3], keepdim=True) + prior_std = prior_pred.std([2,3], keepdim=True) + noise_mean = noise_pred.mean([2,3], keepdim=True) + noise_std = noise_pred.std([2,3], keepdim=True) + + mean_adjust = prior_mean - noise_mean + std_adjust = prior_std - noise_std + + mean_adjust = 
mean_adjust * self.train_config.correct_pred_norm_multiplier + std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier + + target_mean = noise_mean + mean_adjust + target_std = noise_std + std_adjust + + eps = 1e-5 + # match the noise to the prior + noise = (noise - noise_mean) / (noise_std + eps) + noise = noise * (target_std + eps) + target_mean + noise = noise.detach() + + if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: + assert not self.train_config.train_turbo + with torch.no_grad(): + prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype) + if len(noise_pred.shape) == 5: + # video B,C,T,H,W + lat_height = batch.latents.shape[3] + lat_width = batch.latents.shape[4] + else: + lat_height = batch.latents.shape[2] + lat_width = batch.latents.shape[3] + # resize to size of noise_pred + prior_mask = torch.nn.functional.interpolate(prior_mask, size=(lat_height, lat_width), mode='bicubic') + # stack first channel to match channels of noise_pred + prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1) + + if len(noise_pred.shape) == 5: + prior_mask = prior_mask.unsqueeze(2) # add time dimension back for video + prior_mask = prior_mask.repeat(1, 1, noise_pred.shape[2], 1, 1) + + prior_mask_multiplier = 1.0 - prior_mask + + # scale so it is a mean of 1 + prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean() + if hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + elif self.sd.is_flow_matching: + target = (noise - batch.latents).detach() + else: + target = noise + elif prior_pred is not None and not self.train_config.do_prior_divergence: + assert not self.train_config.train_turbo + # matching adapter prediction + target = prior_pred + elif self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, 
timesteps) + elif self.train_config.do_signal_amplification: + if not self.sd.is_flow_matching: + raise ValueError("Signal amplification is only supported for flow matching models") + with torch.no_grad(): + nas = 1.0 - (timesteps / 1000).to(noise.device, dtype=noise.dtype) + nas = nas * self.train_config.signal_amplification_strength + while len(nas.shape) < len(noise.shape): + nas = nas.unsqueeze(-1) + aug = batch.latents * nas + target = noise - (batch.latents + aug) + target = target.detach() + elif hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + + elif self.sd.is_flow_matching: + # forward ODE + target = (noise - batch.latents).detach() + # reverse ODE + # target = (batch.latents - noise).detach() + else: + target = noise + + if self.dfe is not None: + if self.dfe.version == 1: + model = self.sd + if model is not None and hasattr(model, 'get_stepped_pred'): + stepped_latents = model.get_stepped_pred(noise_pred, noise) + else: + # stepped_latents = noise - noise_pred + # first we step the scheduler from current timestep to the very end for a full denoise + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + self.sd.noise_scheduler._step_index = None + self.sd.noise_scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index] + sigma_next = self.sd.noise_scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) + + stepped_latents = stepped_latents.to(self.sd.vae.device, 
dtype=self.sd.vae.dtype) + sl = stepped_latents + if len(sl.shape) == 5: + # video B,C,T,H,W + sl = sl.permute(0, 2, 1, 3, 4) # B,T,C,H,W + b, t, c, h, w = sl.shape + sl = sl.reshape(b * t, c, h, w) + pred_features = self.dfe(sl.float()) + with torch.no_grad(): + bl = batch.latents + bl = bl.to(self.sd.vae.device) + if len(bl.shape) == 5: + # video B,C,T,H,W + bl = bl.permute(0, 2, 1, 3, 4) # B,T,C,H,W + b, t, c, h, w = bl.shape + bl = bl.reshape(b * t, c, h, w) + target_features = self.dfe(bl.float()) + # scale dfe so it is weaker at higher noise levels + dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch) + + dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \ + self.train_config.diffusion_feature_extractor_weight * dfe_scaler + additional_loss += dfe_loss.mean() + elif self.dfe.version == 2: + # version 2 + # do diffusion feature extraction on target + with torch.no_grad(): + rectified_flow_target = noise.float() - batch.latents.float() + target_feature_list = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1)) + + # do diffusion feature extraction on prediction + pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1)) + + dfe_loss = 0.0 + for i in range(len(target_feature_list)): + dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean") + + additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0 + elif self.dfe.version in [3, 4, 5, 6, 7, 8, 9, 10]: + dfe_loss = self.dfe( + noise=noise, + noise_pred=noise_pred, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + scheduler=self.sd.noise_scheduler + ) + additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight + else: + raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}") + + if self.train_config.do_guidance_loss: + with torch.no_grad(): + # 
we make cached blank prompt embeds that match the batch size + unconditional_embeds = concat_prompt_embeds( + [self.unconditional_embeds] * noisy_latents.shape[0], + ) + unconditional_target = self.predict_noise( + noisy_latents=noisy_latents, + timesteps=timesteps, + conditional_embeds=unconditional_embeds, + unconditional_embeds=None, + batch=batch, + ) + is_video = len(target.shape) == 5 + + if self.train_config.do_guidance_loss_cfg_zero: + # zero cfg + # ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557 + batch_size = target.shape[0] + positive_flat = target.view(batch_size, -1) + negative_flat = unconditional_target.view(batch_size, -1) + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + alpha = st_star + + alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1) + else: + alpha = 1.0 + + guidance_scale = self._guidance_loss_target_batch + if isinstance(guidance_scale, list): + guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype) + guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1) + + unconditional_target = unconditional_target * alpha + target = unconditional_target + guidance_scale * (target - unconditional_target) + + if self.train_config.do_differential_guidance: + with torch.no_grad(): + guidance_scale = self.train_config.differential_guidance_scale + target = noise_pred + guidance_scale * (target - noise_pred) + + if target is None: + target = noise + + pred = noise_pred + + if self.train_config.train_turbo: + pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch) + + ignore_snr 
= False + + if loss_target == 'source' or loss_target == 'unaugmented': + assert not self.train_config.train_turbo + # ignore_snr = True + if batch.sigmas is None: + raise ValueError("Batch sigmas is None. This should not happen") + + # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190 + denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents + weighing = batch.sigmas ** -2.0 + if loss_target == 'source': + # denoise the latent and compare to the latent in the batch + target = batch.latents + elif loss_target == 'unaugmented': + # we have to encode images into latents for now + # we also denoise as the unaugmented tensor is not a noisy diffirental + with torch.no_grad(): + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype) + unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier + target = unaugmented_latents.detach() + + # Get the target for loss depending on the prediction type + if self.sd.noise_scheduler.config.prediction_type == "epsilon": + target = target # we are computing loss against denoise latents + elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": + target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") + + # mse loss without reduction + loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2) + loss = loss_per_element + else: + local_loss_scale = 1.0 + if self.train_config.t0_loss_target or self.train_config.do_fft_loss: + # do the loss on a stepped timestep 0 prediction + # doto handle doing priors, preservations, masking, etc + with torch.no_grad(): + tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0 + # expand shape to match noise_pred + while len(tv.shape) < len(noise_pred.shape): + 
tv = tv.unsqueeze(-1) + # min 0.001 + tv = torch.clamp(tv, min=0.001) + + # step latent, use here or with do_fft_loss + t0 = noisy_latents - tv * noise_pred + + if self.train_config.t0_loss_target: + # replace the loss targets and pred + target = batch.latents.detach() + pred = t0 + # handle velocity equiv loss if set. This scales t0 loss to match velocity of flowmatchhing loss + if self.train_config.t0_velocity_equiv_weight: + velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2) + local_loss_scale = velocity_equiv_weight + + if self.train_config.do_fft_loss: + with torch.no_grad(): + target_mag = torch.fft.rfft2(batch.latents.to(t0.device).float(), norm="ortho").abs() + pred_mag = torch.fft.rfft2(t0.float(), norm="ortho").abs() + fft_loss = F.mse_loss(pred_mag, target_mag, reduction="none") + if self.train_config.do_fft_velocity_equiv_weight: + velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2) + fft_loss = fft_loss * velocity_equiv_weight + additional_loss += fft_loss.mean() + if self.train_config.loss_type == "pseudo_huber": + diff = pred.float() - target.float() + c=0.01 + loss =(torch.sqrt(diff.pow(2) + c ** 2) - c) + elif self.train_config.loss_type == "mae": + loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") + elif self.train_config.loss_type == "wavelet": + loss = wavelet_loss(pred, batch.latents, noise) + elif self.train_config.loss_type == "stepped": + loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler) + # the way this loss works, it is low, increase it to match predictable LR effects + loss = loss * 10.0 + else: + loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + + loss = loss * local_loss_scale + + # apply model specific loss scaling + loss = self.sd.scale_loss(loss) + + do_weighted_timesteps = False + if self.sd.is_flow_matching: + if self.train_config.linear_timesteps or self.train_config.linear_timesteps2: + 
do_weighted_timesteps = True + if self.train_config.timestep_type == "weighted": + # use the noise scheduler to get the weights for the timesteps + do_weighted_timesteps = True + + # handle linear timesteps and only adjust the weight of the timesteps + if do_weighted_timesteps: + # calculate the weights for the timesteps + timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps( + timesteps, + v2=self.train_config.linear_timesteps2, + timestep_type=self.train_config.timestep_type + ).to(loss.device, dtype=loss.dtype) + if len(loss.shape) == 4: + timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() + elif len(loss.shape) == 5: + timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach() + loss = loss * timestep_weight + + if self.train_config.do_prior_divergence and prior_pred is not None: + loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0) + + if self.train_config.train_turbo: + mask_multiplier = mask_multiplier[:, 3:, :, :] + # resize to the size of the loss + mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest') + + # multiply by our mask + try: + if len(noise_pred.shape) == 5: + # video B,C,T,H,W + mask_multiplier = mask_multiplier.unsqueeze(2) # add time dimension back for video + mask_multiplier = mask_multiplier.repeat(1, 1, noise_pred.shape[2], 1, 1) + loss = loss * mask_multiplier + except Exception as e: + # todo handle mask with video models + print("Could not apply mask multiplier to loss") + print(e) + pass + + prior_loss = None + if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None: + assert not self.train_config.train_turbo + if self.train_config.loss_type == "mae": + prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none") + else: + prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), 
reduction="none") + + prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier + if torch.isnan(prior_loss).any(): + print_acc("Prior loss is nan") + prior_loss = None + else: + if len(noise_pred.shape) == 5: + # video B,C,T,H,W + prior_loss = prior_loss.mean([1, 2, 3, 4]) + else: + prior_loss = prior_loss.mean([1, 2, 3]) + # loss = loss + prior_loss + # loss = loss + prior_loss + # loss = loss + prior_loss + if len(noise_pred.shape) == 5: + loss = loss.mean([1, 2, 3, 4]) + else: + loss = loss.mean([1, 2, 3]) + # apply loss multiplier before prior loss + # multiply by our mask + try: + loss = loss * loss_multiplier + except: + # todo handle mask with video models + pass + if prior_loss is not None: + loss = loss + prior_loss + + if not self.train_config.train_turbo: + if self.train_config.learnable_snr_gos: + # add snr_gamma + loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos) + elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr: + # add snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, + fixed=True) + elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # check for audio loss + if batch.audio_pred is not None and batch.audio_target is not None: + audio_loss = torch.nn.functional.mse_loss(batch.audio_pred.float(), batch.audio_target.float(), reduction="mean") + audio_loss = audio_loss * self.train_config.audio_loss_multiplier + loss = loss + audio_loss + + # check for additional losses + if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None: + + loss = loss + self.adapter.additional_loss.mean() + self.adapter.additional_loss = 
None + + if self.train_config.target_norm_std: + # seperate out the batch and channels + pred_std = noise_pred.std([2, 3], keepdim=True) + norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean() + loss = loss + norm_std_loss + + + loss = loss + additional_loss + + if self.train_config.max_loss_debug and self.train_config.max_loss is not None: + if loss.item() > self.train_config.max_loss: + print_acc(f"Loss {loss.item()} is greater than max loss {self.train_config.max_loss}. Clipping to max loss.") + print_acc(f"timesteps: {timesteps}") + + if self.train_config.max_loss is not None: + loss = torch.clamp(loss, max=self.train_config.max_loss) + + return loss + + def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): + return batch + + def get_guided_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + loss = get_guidance_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + sd=self.sd, + unconditional_embeds=unconditional_embeds, + train_config=self.train_config, + **kwargs + ) + + return loss + + + # ------------------------------------------------------------------ + # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative + # Modelling”, 2025 – see Alg. 1 + Eq. 
(6) of the paper) + # This version avoids jvp / double-back-prop issues with Flash-Attention + # adapted from the work of lodestonerock + # ------------------------------------------------------------------ + def get_mean_flow_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + dtype = get_torch_dtype(self.train_config.dtype) + total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps) # e.g. 1000 + base_eps = 1e-3 + min_time_gap = 1e-2 + + with torch.no_grad(): + num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps + batch_size = batch.latents.shape[0] + timestep_t_list = [] + timestep_r_list = [] + + for i in range(batch_size): + t1 = random.randint(0, num_train_timesteps - 1) + t2 = random.randint(0, num_train_timesteps - 1) + t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)] + t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)] + if (t_t - t_r).item() < min_time_gap * 1000: + scaled_time_gap = min_time_gap * 1000 + if t_t.item() + scaled_time_gap > 1000: + t_r = t_r - scaled_time_gap + else: + t_t = t_t + scaled_time_gap + timestep_t_list.append(t_t) + timestep_r_list.append(t_r) + + timesteps_t = torch.stack(timestep_t_list, dim=0).float() + timesteps_r = torch.stack(timestep_r_list, dim=0).float() + + t_frac = timesteps_t / total_steps # [0,1] + r_frac = timesteps_r / total_steps # [0,1] + + latents_clean = batch.latents.to(dtype) + noise_sample = noise.to(dtype) + + lerp_vector = latents_clean * (1.0 - t_frac[:, None, None, None]) + noise_sample * t_frac[:, None, None, None] + + eps = base_eps + + # concatenate timesteps as input for u(z, r, t) + timesteps_cat = torch.cat([t_frac, r_frac], dim=0) * total_steps + + # model predicts u(z, r, t) + u_pred = 
def get_prior_prediction(
        self,
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        unconditional_embeds: Optional[PromptEmbeds] = None,
        conditioned_prompts=None,
        **kwargs
):
    """
    Predict with the trained network (and any IP/Reference/Custom adapter)
    switched off, under torch.no_grad() and with the unet in eval mode.

    The network's and adapter's is_active flags and the unet's training mode
    are restored before returning, including when an exception is raised.

    When match_adapter_assist is truthy, the adapter residual entries
    ('down_intrablock_additional_residuals', 'down_block_additional_residuals',
    'mid_block_additional_residual') are deleted from the caller's pred_kwargs
    dict in place after the prediction.

    network_weight_list, noise, conditioned_prompts and **kwargs are accepted
    but not read by this method.

    :param noisy_latents: latents to predict on.
    :param conditional_embeds: prompt embeddings; re-encoded without trigger
        tokens when a ClipVisionAdapter or an embedding is being trained.
    :param timesteps: timesteps passed to the model.
    :param pred_kwargs: extra keyword arguments for self.sd.predict_noise.
    :param batch: current batch (captions, cached prompt embeds, control tensor).
    :param unconditional_embeds: optional unconditional embeddings.
    :return: the detached prediction.
    :raises ValueError: if the text encoder is unloaded while a non-Custom
        adapter is in use.
    """
    # Validate before touching any state so a rejected call leaves the
    # network/adapter flags exactly as they were.
    if (
            self.train_config.unload_text_encoder
            and self.adapter is not None
            and not isinstance(self.adapter, CustomAdapter)
    ):
        raise ValueError("Prior predictions currently do not support unloading text encoder with adapter")

    # todo for embeddings, we need to run without trigger words
    was_unet_training = self.sd.unet.training
    was_network_active = False
    if self.network is not None:
        was_network_active = self.network.is_active
        self.network.is_active = False
    can_disable_adapter = False
    was_adapter_active = False
    if isinstance(self.adapter, (IPAdapter, ReferenceAdapter, CustomAdapter)):
        can_disable_adapter = True
        was_adapter_active = self.adapter.is_active
        self.adapter.is_active = False

    try:
        # do a prediction here so we can match its output with network multiplier set to 0.0
        with torch.no_grad():
            dtype = get_torch_dtype(self.train_config.dtype)

            embeds_to_use = conditional_embeds.clone().detach()
            # handle clip vision adapter by removing triggers from prompt and replacing with the class name
            if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
                prompt_list = batch.get_caption_list()
                class_name = ''

                triggers = ['[trigger]', '[name]']
                remove_tokens = []

                if self.embed_config is not None:
                    triggers.append(self.embed_config.trigger)
                    for i in range(1, self.embed_config.tokens):
                        remove_tokens.append(f"{self.embed_config.trigger}_{i}")
                    if self.embed_config.trigger_class_name is not None:
                        class_name = self.embed_config.trigger_class_name

                if self.adapter is not None:
                    triggers.append(self.adapter_config.trigger)
                    for i in range(1, self.adapter_config.num_tokens):
                        remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
                    if self.adapter_config.trigger_class_name is not None:
                        class_name = self.adapter_config.trigger_class_name

                for idx, prompt in enumerate(prompt_list):
                    for remove_token in remove_tokens:
                        prompt = prompt.replace(remove_token, '')
                    for trigger in triggers:
                        prompt = prompt.replace(trigger, class_name)
                    prompt_list[idx] = prompt

                if batch.prompt_embeds is not None:
                    embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
                else:
                    prompt_kwargs = {}
                    if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None:
                        prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                    # prompt_kwargs belongs to encode_prompt (as at the other
                    # encode_prompt call sites in this trainer), not to .to()
                    embeds_to_use = self.sd.encode_prompt(
                        prompt_list,
                        long_prompts=self.do_long_prompts,
                        **prompt_kwargs
                    ).to(
                        self.device_torch,
                        dtype=dtype
                    ).detach()

            # dont use network on this
            self.sd.unet.eval()

            if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux and not self.sd.is_lumina2:
                # we need to remove the image embeds from the prompt except for flux
                embeds_to_use = embeds_to_use.clone().detach()
                end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
                embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
                if unconditional_embeds is not None:
                    unconditional_embeds = unconditional_embeds.clone().detach()
                    unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos]

            if unconditional_embeds is not None:
                unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()

            guidance_embedding_scale = self.train_config.cfg_scale
            if self.train_config.do_guidance_loss:
                guidance_embedding_scale = self._guidance_loss_target_batch

            prior_pred = self.sd.predict_noise(
                latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
                conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
                unconditional_embeddings=unconditional_embeds,
                timestep=timesteps,
                guidance_scale=self.train_config.cfg_scale,
                guidance_embedding_scale=guidance_embedding_scale,
                rescale_cfg=self.train_config.cfg_rescale,
                batch=batch,
                **pred_kwargs  # adapter residuals in here
            )
            prior_pred = prior_pred.detach()

            # remove the residuals as we wont use them on prediction when matching control
            for residual_key in (
                    'down_intrablock_additional_residuals',
                    'down_block_additional_residuals',
                    'mid_block_additional_residual',
            ):
                if match_adapter_assist and residual_key in pred_kwargs:
                    del pred_kwargs[residual_key]

        return prior_pred
    finally:
        # restore everything that was switched off above, even on failure
        if was_unet_training:
            self.sd.unet.train()
        if can_disable_adapter:
            self.adapter.is_active = was_adapter_active
        if self.network is not None:
            self.network.is_active = was_network_active
def before_unet_predict(self):
    """No-op hook; train_single_accumulation calls it right before the model prediction step."""

def after_unet_predict(self):
    """No-op hook; train_single_accumulation calls it right after the main predict_noise call."""

def end_of_training_loop(self):
    """No-op hook; its call site is not visible in this part of the file."""

def predict_noise(
        self,
        noisy_latents: torch.Tensor,
        timesteps: Union[int, torch.Tensor] = 1,
        conditional_embeds: Union[PromptEmbeds, None] = None,
        unconditional_embeds: Union[PromptEmbeds, None] = None,
        batch: Optional['DataLoaderBatchDTO'] = None,
        is_primary_pred: bool = False,
        **kwargs,
):
    """
    Run self.sd.predict_noise with this trainer's configured guidance settings.

    The latents and conditional embeddings are moved to self.device_torch in
    the training dtype first; unconditional_embeds is passed through as given.
    The guidance embedding scale is self._guidance_loss_target_batch when
    train_config.do_guidance_loss is set, otherwise train_config.cfg_scale.

    is_primary_pred is accepted but not read here. Extra keyword arguments are
    forwarded to self.sd.predict_noise.

    :return: whatever self.sd.predict_noise returns.
    """
    cfg = self.train_config
    target_dtype = get_torch_dtype(cfg.dtype)

    if cfg.do_guidance_loss:
        embedding_scale = self._guidance_loss_target_batch
    else:
        embedding_scale = cfg.cfg_scale

    latents_on_device = noisy_latents.to(self.device_torch, dtype=target_dtype)
    cond_on_device = conditional_embeds.to(self.device_torch, dtype=target_dtype)

    return self.sd.predict_noise(
        latents=latents_on_device,
        conditional_embeddings=cond_on_device,
        unconditional_embeddings=unconditional_embeds,
        timestep=timesteps,
        guidance_scale=cfg.cfg_scale,
        guidance_embedding_scale=embedding_scale,
        detach_unconditional=False,
        rescale_cfg=cfg.cfg_rescale,
        bypass_guidance_embedding=cfg.bypass_guidance_embedding,
        batch=batch,
        **kwargs
    )
timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + guidance_embedding_scale=guidance_embedding_scale, + detach_unconditional=False, + rescale_cfg=self.train_config.cfg_rescale, + bypass_guidance_embedding=self.train_config.bypass_guidance_embedding, + batch=batch, + **kwargs + ) + + + def train_single_accumulation(self, batch: DataLoaderBatchDTO): + print_acc("[spock-debug] train_single_accumulation: ENTER", flush=True) + with torch.no_grad(): + self.timer.start('preprocess_batch') + print_acc("[spock-debug] train_single_accumulation: preprocess_batch", flush=True) + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_raw(batch) + batch = self.preprocess_batch(batch) + print_acc("[spock-debug] train_single_accumulation: preprocess_batch DONE", flush=True) + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_processed(batch) + dtype = get_torch_dtype(self.train_config.dtype) + # sanity check + if self.sd.vae.dtype != self.sd.vae_torch_dtype: + self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + if encoder.dtype != self.sd.te_torch_dtype: + encoder.to(self.sd.te_torch_dtype) + else: + if self.sd.text_encoder.dtype != self.sd.te_torch_dtype: + self.sd.text_encoder.to(self.sd.te_torch_dtype) + + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + if self.train_config.do_cfg or self.train_config.do_random_cfg: + # pick random negative prompts + if self.negative_prompt_pool is not None: + negative_prompts = [] + for i in range(noisy_latents.shape[0]): + num_neg = random.randint(1, self.train_config.max_negative_prompts) + this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)] + this_neg_prompt = ', '.join(this_neg_prompts) + negative_prompts.append(this_neg_prompt) + self.batch_negative_prompt = negative_prompts + else: + 
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] + + if self.adapter and isinstance(self.adapter, CustomAdapter): + # condition the prompt + # todo handle more than one adapter image + conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) + + network_weight_list = batch.get_network_weight_list() + if self.train_config.single_item_batching: + network_weight_list = network_weight_list + network_weight_list + + has_adapter_img = batch.control_tensor is not None + has_clip_image = batch.clip_image_tensor is not None + has_clip_image_embeds = batch.clip_image_embeds is not None + # force it to be true if doing regs as we handle those differently + if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]): + has_clip_image = True + if self._clip_image_embeds_unconditional is not None: + has_clip_image_embeds = True # we are caching embeds, handle that differently + has_clip_image = False + + # do prior pred if prior regularization batch + do_reg_prior = False + if any([batch.file_items[idx].prior_reg for idx in range(len(batch.file_items))]): + do_reg_prior = True + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: + raise ValueError( + "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. 
Please update your dataset config ") + + match_adapter_assist = False + + # check if we are matching the adapter assistant + if self.assistant_adapter: + if self.train_config.match_adapter_chance == 1.0: + match_adapter_assist = True + elif self.train_config.match_adapter_chance > 0.0: + match_adapter_assist = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) < self.train_config.match_adapter_chance + + self.timer.stop('preprocess_batch') + + is_reg = False + loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + for idx, file_item in enumerate(batch.file_items): + if file_item.is_reg: + loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight + is_reg = True + + adapter_images = None + sigmas = None + if has_adapter_img and (self.adapter or self.assistant_adapter): + with self.timer('get_adapter_images'): + # todo move this to data loader + if batch.control_tensor is not None: + adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() + # match in channels + if self.assistant_adapter is not None: + in_channels = self.assistant_adapter.config.in_channels + if adapter_images.shape[1] != in_channels: + # we need to match the channels + adapter_images = adapter_images[:, :in_channels, :, :] + else: + raise NotImplementedError("Adapter images now must be loaded with dataloader") + + clip_images = None + if has_clip_image: + with self.timer('get_clip_images'): + # todo move this to data loader + if batch.clip_image_tensor is not None: + clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach() + + mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + if batch.mask_tensor is not None and self.sd.do_masked_loss: + with self.timer('get_mask_multiplier'): + # upsampling no supported for bfloat16 + mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach() + # scale 
down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height) + if len(noisy_latents.shape) == 5: + # video B,C,T,H,W + h = noisy_latents.shape[3] + w = noisy_latents.shape[4] + else: + h = noisy_latents.shape[2] + w = noisy_latents.shape[3] + mask_multiplier = torch.nn.functional.interpolate( + mask_multiplier, size=(h, w) + ) + # expand to match latents + mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) + mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + # make avg 1.0 + mask_multiplier = mask_multiplier / mask_multiplier.mean() + + def get_adapter_multiplier(): + if self.adapter and isinstance(self.adapter, T2IAdapter): + # training a t2i adapter, not using as assistant. + return 1.0 + elif match_adapter_assist: + # training a texture. We want it high + adapter_strength_min = 0.9 + adapter_strength_max = 1.0 + else: + # training with assistance, we want it low + # adapter_strength_min = 0.4 + # adapter_strength_max = 0.7 + adapter_strength_min = 0.5 + adapter_strength_max = 1.1 + + adapter_conditioning_scale = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) + + adapter_conditioning_scale = value_map( + adapter_conditioning_scale, + 0.0, + 1.0, + adapter_strength_min, + adapter_strength_max + ) + return adapter_conditioning_scale + + # flush() + with self.timer('grad_setup'): + + # text encoding + grad_on_text_encoder = False + if self.train_config.train_text_encoder: + grad_on_text_encoder = True + + if self.embedding is not None: + grad_on_text_encoder = True + + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + grad_on_text_encoder = True + + if self.adapter_config and self.adapter_config.type == 'te_augmenter': + grad_on_text_encoder = True + + # have a blank network so we can wrap it in a context and set multipliers without checking every time + if self.network is not None: + network = self.network + else: 
+ network = BlankNetwork() + + # set the weights + network.multiplier = network_weight_list + + # activate network if it exits + + prompts_1 = conditioned_prompts + prompts_2 = None + if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl: + prompts_1 = batch.get_caption_short_list() + prompts_2 = conditioned_prompts + + # make the batch splits + if self.train_config.single_item_batching: + if self.model_config.refiner_name_or_path is not None: + raise ValueError("Single item batching is not supported when training the refiner") + batch_size = noisy_latents.shape[0] + # chunk/split everything + noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0) + noise_list = torch.chunk(noise, batch_size, dim=0) + timesteps_list = torch.chunk(timesteps, batch_size, dim=0) + conditioned_prompts_list = [[prompt] for prompt in prompts_1] + if imgs is not None: + imgs_list = torch.chunk(imgs, batch_size, dim=0) + else: + imgs_list = [None for _ in range(batch_size)] + if adapter_images is not None: + adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0) + else: + adapter_images_list = [None for _ in range(batch_size)] + if clip_images is not None: + clip_images_list = torch.chunk(clip_images, batch_size, dim=0) + else: + clip_images_list = [None for _ in range(batch_size)] + mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0) + if prompts_2 is None: + prompt_2_list = [None for _ in range(batch_size)] + else: + prompt_2_list = [[prompt] for prompt in prompts_2] + + else: + noisy_latents_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditioned_prompts_list = [prompts_1] + imgs_list = [imgs] + adapter_images_list = [adapter_images] + clip_images_list = [clip_images] + mask_multiplier_list = [mask_multiplier] + if prompts_2 is None: + prompt_2_list = [None] + else: + prompt_2_list = [prompts_2] + + for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, 
clip_images, mask_multiplier, prompt_2 in zip( + noisy_latents_list, + noise_list, + timesteps_list, + conditioned_prompts_list, + imgs_list, + adapter_images_list, + clip_images_list, + mask_multiplier_list, + prompt_2_list + ): + + # if self.train_config.negative_prompt is not None: + # # add negative prompt + # conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in + # range(len(conditioned_prompts))] + # if prompt_2 is not None: + # prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] + + with (network): + # encode clip adapter here so embeds are active for tokenizer + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + with self.timer('encode_clip_vision_embeds'): + if has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True + ) + else: + # just do a blank one + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, 512, 512), + device=self.device_torch, dtype=dtype + ), + is_training=True, + has_been_preprocessed=True, + drop=True + ) + # it will be injected into the tokenizer when called + self.adapter(conditional_clip_embeds) + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or is_reg): + quad_count = random.randint(1, 4) + self.adapter.train() + self.adapter.trigger_pre_te( + tensors_preprocessed=clip_images if not is_reg else None, # on regs we send none to get random noise + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + batch_tensor=batch.tensor if not is_reg else None, + batch_size=noisy_latents.shape[0] + ) + + with self.timer('encode_prompt'): + unconditional_embeds = None + prompt_kwargs = {} + if self.sd.encode_control_in_text_embeddings and 
batch.control_tensor is not None: + prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: + with torch.set_grad_enabled(False): + if batch.prompt_embeds is not None: + # use the cached embeds + conditional_embeds = batch.prompt_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + else: + embeds_to_use = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + if self.cached_trigger_embeds is not None and not is_reg: + embeds_to_use = self.cached_trigger_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + conditional_embeds = concat_prompt_embeds( + [embeds_to_use] * noisy_latents.shape[0] + ) + if self.train_config.do_cfg: + unconditional_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + unconditional_embeds = concat_prompt_embeds( + [unconditional_embeds] * noisy_latents.shape[0] + ) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + elif grad_on_text_encoder: + with torch.set_grad_enabled(True): + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + # todo only do one and repeat it + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + 
self.adapter.is_unconditional_run = False + else: + with torch.set_grad_enabled(False): + # make sure it is in eval mode + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + else: + self.sd.text_encoder.eval() + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + if self.sd.encode_control_in_text_embeddings and batch.control_tensor_list is not None: + prompt_kwargs['control_images'] = batch.control_tensor_list + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + if self.train_config.diff_output_preservation: + dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts] + dop_prompts_2 = None + if prompt_2 is not None: + dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2] + self.diff_output_preservation_embeds = self.sd.encode_prompt( + dop_prompts, dop_prompts_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts, + **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype) + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + if self.train_config.do_cfg: + unconditional_embeds = unconditional_embeds.detach() + + if self.decorator: + conditional_embeds.text_embeds = self.decorator( + 
conditional_embeds.text_embeds + ) + if self.train_config.do_cfg: + unconditional_embeds.text_embeds = self.decorator( + unconditional_embeds.text_embeds, + is_unconditional=True + ) + + # flush() + pred_kwargs = {} + + if has_adapter_img: + if (self.adapter and isinstance(self.adapter, T2IAdapter)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): + with torch.set_grad_enabled(self.adapter is not None): + adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + down_block_additional_residuals = adapter(adapter_images) + if self.assistant_adapter: + # not training. detach + down_block_additional_residuals = [ + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in + down_block_additional_residuals + ] + else: + down_block_additional_residuals = [ + sample.to(dtype=dtype) * adapter_multiplier for sample in + down_block_additional_residuals + ] + + pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals + + if self.adapter and isinstance(self.adapter, IPAdapter): + with self.timer('encode_adapter_embeds'): + # number of images to do if doing a quad image + quad_count = random.randint(1, 4) + image_size = self.adapter.input_size + if has_clip_image_embeds: + # todo handle reg images better than this + if is_reg: + # get unconditional image embeds from cache + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + if self.train_config.do_cfg: + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + else: + 
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds_unconditional, + quad_count=quad_count + ) + elif is_reg: + # we will zero it out in the img embedder + clip_images = torch.zeros( + (noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach() + # drop will zero it out + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images, + drop=True, + is_training=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach(), + is_training=True, + drop=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + elif has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + # do cfg on clip embeds to normalize the embeddings for when doing cfg + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + drop=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + else: + print_acc("No Clip Image") + print_acc([file_item.path for file_item in batch.file_items]) + raise ValueError("Could not find clip image") + + if not self.adapter_config.train_image_encoder: + # we are not training the image encoder, so we 
need to detach the embeds + conditional_clip_embeds = conditional_clip_embeds.detach() + if self.train_config.do_cfg: + unconditional_clip_embeds = unconditional_clip_embeds.detach() + + with self.timer('encode_adapter'): + self.adapter.train() + conditional_embeds = self.adapter( + conditional_embeds.detach(), + conditional_clip_embeds, + is_unconditional=False + ) + if self.train_config.do_cfg: + unconditional_embeds = self.adapter( + unconditional_embeds.detach(), + unconditional_clip_embeds, + is_unconditional=True + ) + else: + # wipe out unconsitional + self.adapter.last_unconditional = None + + if self.adapter and isinstance(self.adapter, ReferenceAdapter): + # pass in our scheduler + self.adapter.noise_scheduler = self.lr_scheduler + if has_clip_image or has_adapter_img: + img_to_use = clip_images if has_clip_image else adapter_images + # currently 0-1 needs to be -1 to 1 + reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype) + self.adapter.set_reference_images(reference_images) + self.adapter.noise_scheduler = self.sd.noise_scheduler + elif is_reg: + self.adapter.set_blank_reference_images(noisy_latents.shape[0]) + else: + self.adapter.set_reference_images(None) + + prior_pred = None + + do_inverted_masked_prior = False + if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: + do_inverted_masked_prior = True + + do_correct_pred_norm_prior = self.train_config.correct_pred_norm + + do_guidance_prior = False + + if batch.unconditional_latents is not None: + # for this not that, we need a prior pred to normalize + guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type + if guidance_type == 'tnt': + do_guidance_prior = True + + if (( + has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): + with self.timer('prior predict'): + 
prior_embeds_to_use = conditional_embeds + # use diff_output_preservation embeds if doing dfe + if self.train_config.diff_output_preservation: + prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) + + if self.train_config.blank_prompt_preservation: + blank_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + prior_embeds_to_use = concat_prompt_embeds( + [blank_embeds] * noisy_latents.shape[0] + ) + + prior_pred = self.get_prior_prediction( + noisy_latents=noisy_latents, + conditional_embeds=prior_embeds_to_use, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + noise=noise, + batch=batch, + unconditional_embeds=unconditional_embeds, + conditioned_prompts=conditioned_prompts + ) + if prior_pred is not None: + prior_pred = prior_pred.detach() + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or self.adapter_config.type in ['llm_adapter', 'text_encoder']): + quad_count = random.randint(1, 4) + self.adapter.train() + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=conditional_embeds, + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + if self.train_config.do_cfg and unconditional_embeds is not None: + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=unconditional_embeds, + is_training=True, + has_been_preprocessed=True, + is_unconditional=True, + quad_count=quad_count + ) + + if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None: + self.adapter.add_extra_values(batch.extra_values.detach()) + + if self.train_config.do_cfg: + self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), + is_unconditional=True) + + if 
has_adapter_img: + if (self.adapter and isinstance(self.adapter, ControlNetModel)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): + if self.train_config.do_cfg: + raise ValueError("ControlNetModel is not supported with CFG") + with torch.set_grad_enabled(self.adapter is not None): + adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + # add_text_embeds is pooled_prompt_embeds for sdxl + added_cond_kwargs = {} + if self.sd.is_xl: + added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds + added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents) + down_block_res_samples, mid_block_res_sample = adapter( + noisy_latents, + timesteps, + encoder_hidden_states=conditional_embeds.text_embeds, + controlnet_cond=adapter_images, + conditioning_scale=1.0, + guess_mode=False, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + pred_kwargs['down_block_additional_residuals'] = down_block_res_samples + pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample + + if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list): + batch_size = noisy_latents.shape[0] + # update the guidance value, random float between guidance_loss_target[0] and guidance_loss_target[1] + self._guidance_loss_target_batch = [ + random.uniform( + self.train_config.guidance_loss_target[0], + self.train_config.guidance_loss_target[1] + ) for _ in range(batch_size) + ] + + self.before_unet_predict() + + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + with self.timer('condition_noisy_latents'): + # do it for the model + noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch) + if self.adapter and isinstance(self.adapter, CustomAdapter): + 
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) + + if self.train_config.timestep_type == 'next_sample': + with self.timer('next_sample_step'): + with torch.no_grad(): + + stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps] + stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies] + stepped_timesteps = torch.stack(stepped_timesteps, dim=0) + + # do a sample at the current timestep and step it, then determine new noise + next_sample_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + stepped_latents = self.sd.step_scheduler( + next_sample_pred, + noisy_latents, + timesteps, + self.sd.noise_scheduler + ) + # stepped latents is our new noisy latents. Now we need to determine noise in the current sample + noisy_latents = stepped_latents + original_samples = batch.latents.to(self.device_torch, dtype=dtype) + # todo calc next timestep, for now this may work as it + t_01 = (stepped_timesteps / 1000).to(original_samples.device) + if len(stepped_latents.shape) == 4: + t_01 = t_01.view(-1, 1, 1, 1) + elif len(stepped_latents.shape) == 5: + t_01 = t_01.view(-1, 1, 1, 1, 1) + else: + raise ValueError("Unknown stepped latents shape", stepped_latents.shape) + next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01 + noise = next_sample_noise + timesteps = stepped_timesteps + # do a prior pred if we have an unconditional image, we will swap out the giadance later + if batch.unconditional_latents is not None or self.do_guided_loss: + # do guided loss + loss = self.get_guided_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + 
network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + unconditional_embeds=unconditional_embeds, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + ) + + elif self.train_config.loss_type == 'mean_flow': + loss = self.get_mean_flow_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + unconditional_embeds=unconditional_embeds, + prior_pred=prior_pred, + ) + else: + with self.timer('predict_unet'): + noise_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + is_primary_pred=True, + **pred_kwargs + ) + self.after_unet_predict() + + with self.timer('calculate_loss'): + noise = noise.to(self.device_torch, dtype=dtype).detach() + prior_to_calculate_loss = prior_pred + # if we are doing diff_output_preservation and not noing inverted masked prior + # then we need to send none here so it will not target the prior + doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation + if doing_preservation and not do_inverted_masked_prior: + prior_to_calculate_loss = None + + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_to_calculate_loss, + ) + + if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation: + with torch.no_grad(): + if self.train_config.diff_output_preservation: + preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) + elif 
self.train_config.blank_prompt_preservation: + blank_embeds = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) + preservation_embeds = concat_prompt_embeds( + [blank_embeds] * noisy_latents.shape[0] + ) + preservation_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier + preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier + self.additional_logs['loss/normal'] = loss.item() + self.additional_logs['loss/preservation'] = preservation_loss.item() + loss = loss + preservation_loss + + # check if nan + if torch.isnan(loss): + print_acc("loss is nan") + loss = torch.zeros_like(loss).requires_grad_(True) + + with self.timer('backward'): + # todo we have multiplier seperated. works for now as res are not in same batch, but need to change + loss = loss * loss_multiplier.mean() + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. 
DON'T DO IT + # with fsdp_overlap_step_with_backward(): + # if self.is_bfloat: + # loss.backward() + # else: + self.accelerator.backward(loss) + + return loss.detach() + # flush() + + def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]): + if isinstance(batch, list): + batch_list = batch + else: + batch_list = [batch] + total_loss = None + self.optimizer.zero_grad() + for batch in batch_list: + if self.sd.is_multistage: + # handle multistage switching + if self.steps_this_boundary >= self.train_config.switch_boundary_every or self.current_boundary_index not in self.sd.trainable_multistage_boundaries: + # iterate to make sure we only train trainable_multistage_boundaries + while True: + self.steps_this_boundary = 0 + self.current_boundary_index += 1 + if self.current_boundary_index >= len(self.sd.multistage_boundaries): + self.current_boundary_index = 0 + if self.current_boundary_index in self.sd.trainable_multistage_boundaries: + # if this boundary is trainable, we can stop looking + break + loss = self.train_single_accumulation(batch) + self.steps_this_boundary += 1 + if total_loss is None: + total_loss = loss + else: + total_loss += loss + if len(batch_list) > 1 and self.model_config.low_vram: + torch.cuda.empty_cache() + + + if not self.is_grad_accumulation_step: + # fix this for multi params + if self.train_config.optimizer != 'adafactor': + if isinstance(self.params[0], dict): + for i in range(len(self.params)): + self.accelerator.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm) + else: + self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm) + # only step if we are not accumulating + with self.timer('optimizer_step'): + self.optimizer.step() + + self.optimizer.zero_grad(set_to_none=True) + if self.adapter and isinstance(self.adapter, CustomAdapter): + self.adapter.post_weight_update() + if self.ema is not None: + with self.timer('ema_update'): + self.ema.update() + else: + # 
# Status values written to Job.status. "queued" is included because maybe_stop()
# writes it when a job is handed back to the queue.
AITK_Status = Literal["running", "stopped", "error", "completed", "queued"]


class UITrainer(SDTrainer):
    """SDTrainer that mirrors its progress into a SQLite ``Job`` table.

    The database path comes from the ``sqlite_db_path`` config key and the row
    is selected by the ``AITK_JOB_ID`` environment variable. The trainer writes
    ``status``, ``info``, ``step`` and ``speed_string`` for that row, and polls
    the row's ``stop`` / ``return_to_queue`` columns at every hook so the job
    can be interrupted from outside the process.
    """

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        """Validate the database path and job id, then mark the job as starting.

        Raises:
            Exception: if the SQLite file does not exist or ``AITK_JOB_ID`` is
                unset or blank.
        """
        super(UITrainer, self).__init__(process_id, job, config, **kwargs)
        self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
        if not os.path.exists(self.sqlite_db_path):
            raise Exception(
                f"SQLite database not found at {self.sqlite_db_path}")
        print(f"Using SQLite database at {self.sqlite_db_path}")
        raw_job_id = os.environ.get("AITK_JOB_ID", None)
        self.job_id = raw_job_id.strip() if raw_job_id is not None else None
        print(f"Job ID: \"{self.job_id}\"")
        # A blank id can never match a Job row, so treat it the same as unset.
        if not self.job_id:
            raise Exception("AITK_JOB_ID not set")
        self.is_stopping = False
        self._stop_watcher_started = False
        # Single worker: database writes are serialized through this pool.
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # Operations that have been scheduled but are not yet finished.
        self._async_tasks = []
        # Initialize the status
        self._run_async_operation(self._update_status("running", "Starting"))
        # self.start_stop_watcher(interval_sec=2.0)

    def start_stop_watcher(self, interval_sec: float = 5.0):
        """
        Start a daemon thread that periodically checks should_stop()
        and terminates the process immediately when triggered.
        Calling it more than once is a no-op.
        """
        if getattr(self, "_stop_watcher_started", False):
            return
        self._stop_watcher_started = True
        t = threading.Thread(
            target=self._stop_watcher_thread, args=(interval_sec,), daemon=True
        )
        t.start()

    def _stop_watcher_thread(self, interval_sec: float):
        """Poll should_stop() every ``interval_sec`` seconds; on stop, send SIGINT to this process."""
        while True:
            try:
                if self.should_stop():
                    # Mark first so on_error() does not overwrite "stopped" with "error".
                    self.is_stopping = True
                    self._run_async_operation(
                        self._update_status("stopped", "Job stopped (remote)")
                    )
                    # Best-effort flush pending async ops
                    try:
                        asyncio.run(self.wait_for_all_async())
                    except RuntimeError:
                        pass
                    # Try to stop DB thread pool quickly.
                    # cancel_futures is rejected with TypeError where unsupported.
                    try:
                        self.thread_pool.shutdown(wait=False, cancel_futures=True)
                    except TypeError:
                        self.thread_pool.shutdown(wait=False)
                    print("")
                    print("****************************************************")
                    print(" Stop signal received; terminating process. ")
                    print("****************************************************")
                    os.kill(os.getpid(), signal.SIGINT)
                time.sleep(interval_sec)
            except Exception:
                # Deliberate best-effort: a failed poll must not kill the watcher.
                time.sleep(interval_sec)

    def _run_async_operation(self, coro):
        """Run ``coro`` on the current event loop and track it until it finishes.

        If the loop is already running the coroutine is scheduled and tracked.
        Otherwise it is driven to completion right here, so the call blocks
        until the operation is done and any exception it raised propagates.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No event loop exists, create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Forget finished work; otherwise the list grows by one entry per
        # training step for the whole run.
        self._async_tasks[:] = [t for t in self._async_tasks if not t.done()]

        if loop.is_running():
            future = asyncio.run_coroutine_threadsafe(coro, loop)
            self._async_tasks.append(asyncio.wrap_future(future))
        else:
            task = loop.create_task(coro)
            self._async_tasks.append(task)
            loop.run_until_complete(task)

    async def _execute_db_operation(self, operation_func):
        """Execute a database operation in a separate thread to avoid blocking."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.thread_pool, operation_func)

    def _db_connect(self):
        """Create a new connection for each operation to avoid locking.

        The caller owns the connection and must close it.
        """
        conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
        conn.isolation_level = None  # Enable autocommit mode
        return conn

    def _fetch_job_flag(self, query: str) -> bool:
        """Run a one-column SELECT bound to this job's id; True only if the value equals 1."""
        conn = self._db_connect()
        try:
            cursor = conn.cursor()
            cursor.execute(query, (self.job_id,))
            row = cursor.fetchone()
            return False if row is None else row[0] == 1
        finally:
            # `with conn:` only manages transactions; it never closes the connection.
            conn.close()

    def _write_in_transaction(self, query: str, params: tuple):
        """Execute one write inside BEGIN IMMEDIATE; commit on success, roll back on failure."""
        conn = self._db_connect()
        try:
            cursor = conn.cursor()
            cursor.execute("BEGIN IMMEDIATE")
            try:
                cursor.execute(query, params)
            except Exception:
                cursor.execute("ROLLBACK")
                raise
            else:
                cursor.execute("COMMIT")
        finally:
            conn.close()

    def should_stop(self):
        """Return True when the job row's ``stop`` column is 1 (False if the row is missing)."""
        return self._fetch_job_flag("SELECT stop FROM Job WHERE id = ?")

    def should_return_to_queue(self):
        """Return True when the job row's ``return_to_queue`` column is 1 (False if the row is missing)."""
        return self._fetch_job_flag(
            "SELECT return_to_queue FROM Job WHERE id = ?")

    def maybe_stop(self):
        """Abort by raising if the job was stopped or sent back to the queue.

        Raises:
            Exception: "Job stopped" or "Job returning to queue".
        """
        if self.should_stop():
            # Set the flag before the DB write so that on_error() still skips
            # the "error" status even if the write itself fails.
            self.is_stopping = True
            self._run_async_operation(
                self._update_status("stopped", "Job stopped"))
            raise Exception("Job stopped")
        if self.should_return_to_queue():
            self.is_stopping = True
            self._run_async_operation(
                self._update_status("queued", "Job queued"))
            raise Exception("Job returning to queue")

    async def _update_key(self, key, value):
        """Set column ``key`` of this job's row to ``str(value)`` (main process only).

        Raises:
            ValueError: if ``key`` is not a plain identifier.
        """
        if not self.accelerator.is_main_process:
            return

        # Column names cannot be bound as SQL parameters, so the name is
        # interpolated into the statement; restrict it to a plain identifier.
        if not isinstance(key, str) or not key.isidentifier():
            raise ValueError(f"Invalid Job column name: {key!r}")

        # Convert the value to string if it's not already
        value_to_insert = value if isinstance(value, str) else str(value)
        update_query = f"UPDATE Job SET {key} = ? WHERE id = ?"

        def _do_update():
            self._write_in_transaction(
                update_query, (value_to_insert, self.job_id))

        await self._execute_db_operation(_do_update)

    def update_step(self):
        """Write the current step count (blocks until written when no event loop is running)."""
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_key("step", self.step_num))

    def update_db_key(self, key, value):
        """Write one column of the job row (blocks until written when no event loop is running)."""
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_key(key, value))

    async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Set ``status`` (and ``info`` when given) on this job's row (main process only)."""
        if not self.accelerator.is_main_process:
            return

        if info is not None:
            update_query = "UPDATE Job SET status = ?, info = ? WHERE id = ?"
            params = (status, info, self.job_id)
        else:
            update_query = "UPDATE Job SET status = ? WHERE id = ?"
            params = (status, self.job_id)

        def _do_update():
            self._write_in_transaction(update_query, params)

        await self._execute_db_operation(_do_update)

    def update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Write the job status (blocks until written when no event loop is running)."""
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_status(status, info))

    async def wait_for_all_async(self):
        """Wait for all tracked async operations to complete (best effort)."""
        # Finished tasks need no waiting, and awaiting a task that belongs to
        # another event loop raises, so only still-pending work is gathered.
        pending = [t for t in self._async_tasks if not t.done()]
        try:
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)
        except Exception as e:
            # Shutdown must proceed regardless; report instead of hiding it.
            print(f"Warning: could not wait for pending status updates: {e}")
        finally:
            # Clear the task list after completion
            self._async_tasks.clear()

    def on_error(self, e: Exception):
        """Record the failure (unless the job is stopping deliberately) and shut the DB pool down."""
        super(UITrainer, self).on_error(e)
        if self.accelerator.is_main_process and not self.is_stopping:
            self.update_status("error", str(e))
            self.update_db_key("step", self.last_save_step)
        asyncio.run(self.wait_for_all_async())
        self.thread_pool.shutdown(wait=True)

    def handle_timing_print_hook(self, timing_dict):
        """Publish training speed as "N iter/sec" or "N sec/iter" from ``timing_dict["train_loop"]``."""
        if "train_loop" not in timing_dict:
            print("train_loop not found in timing_dict", timing_dict)
            return
        seconds_per_iter = timing_dict["train_loop"]
        # determine iter/sec or sec/iter; a zero reading must not be inverted
        if 0 < seconds_per_iter < 1:
            iters_per_sec = 1 / seconds_per_iter
            self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec")
        else:
            self.update_db_key(
                "speed_string", f"{seconds_per_iter:.2f} sec/iter")

    def done_hook(self):
        """Mark the job completed and release the DB thread pool."""
        super(UITrainer, self).done_hook()
        self.update_status("completed", "Training completed")
        # Wait for all async operations to finish before shutting down
        asyncio.run(self.wait_for_all_async())
        self.thread_pool.shutdown(wait=True)

    def end_step_hook(self):
        """After each step: publish the step count, then honor a stop request."""
        super(UITrainer, self).end_step_hook()
        self.update_step()
        self.maybe_stop()

    def hook_before_model_load(self):
        super().hook_before_model_load()
        self.maybe_stop()
        self.update_status("running", "Loading model")

    def before_dataset_load(self):
        super().before_dataset_load()
        self.maybe_stop()
        self.update_status("running", "Loading dataset")

    def hook_before_train_loop(self):
        super().hook_before_train_loop()
        self.maybe_stop()
        self.update_step()
        self.update_status("running", "Training")
        # Publish the speed string every time the timer prints.
        self.timer.add_after_print_hook(self.handle_timing_print_hook)

    def status_update_hook_func(self, string):
        """Callback registered on ``self.sd``: forwards its status text as the job's info."""
        self.update_status("running", string)

    def hook_after_sd_init_before_load(self):
        super().hook_after_sd_init_before_load()
        self.maybe_stop()
        self.sd.add_status_update_hook(self.status_update_hook_func)

    def sample_step_hook(self, img_num, total_imgs):
        super().sample_step_hook(img_num, total_imgs)
        self.maybe_stop()
        self.update_status(
            "running", f"Generating images - {img_num + 1}/{total_imgs}")

    def sample(self, step=None, is_first=False):
        """Sample as the base class does, publishing progress and honoring stop requests around it."""
        self.maybe_stop()
        total_imgs = len(self.sample_config.prompts)
        self.update_status("running", f"Generating images - 0/{total_imgs}")
        super().sample(step, is_first)
        self.maybe_stop()
        self.update_status("running", "Training")

    def save(self, step=None):
        """Save as the base class does, publishing progress and honoring stop requests around it."""
        self.maybe_stop()
        self.update_status("running", "Saving model")
        super().save(step)
        self.maybe_stop()
        self.update_status("running", "Training")
# Extension entry for generic training (LoRA, Dreambooth, FineTuning)
class SDTrainerExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "sd_trainer"

    # name is the name of the extension for printing
    name = "SD Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .SDTrainer import SDTrainer

        return SDTrainer


# Extension entry for UITrainer, the SDTrainer subclass defined in UITrainer.py
class UITrainerExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "ui_trainer"

    # name is the name of the extension for printing
    name = "UI Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .UITrainer import UITrainer

        return UITrainer


# This is a universal trainer that can be from ui or api
class DiffusionTrainerExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "diffusion_trainer"

    # name is the name of the extension for printing
    name = "Diffusion Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .DiffusionTrainer import DiffusionTrainer

        return DiffusionTrainer


# For backwards compatibility: same process class as SDTrainerExtension
# (get_process is inherited), registered under the older uid.
class TextualInversionTrainer(SDTrainerExtension):
    uid = "textual_inversion_trainer"
TextualInversionTrainer, + UITrainerExtension, + DiffusionTrainerExtension, +] diff --git a/ai-toolkit/extensions_built_in/sd_trainer/config/train.example.yaml b/ai-toolkit/extensions_built_in/sd_trainer/config/train.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..793d5d55b9a282d58b53f0e6fafba0b5aa66f7af --- /dev/null +++ b/ai-toolkit/extensions_built_in/sd_trainer/config/train.example.yaml @@ -0,0 +1,91 @@ +--- +job: extension +config: + name: test_v1 + process: + - type: 'textual_inversion_trainer' + training_folder: "out/TI" + device: cuda:0 + # for tensorboard logging + log_dir: "out/.tensorboard" + embedding: + trigger: "your_trigger_here" + tokens: 12 + init_words: "man with short brown hair" + save_format: "safetensors" # 'safetensors' or 'pt' + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 5 # only affects step counts + datasets: + - folder_path: "/path/to/dataset" + caption_ext: "txt" + default_caption: "[trigger]" + buckets: true + resolution: 512 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 3000 + weight_jitter: 0.0 + lr: 5e-5 + train_unet: false + gradient_checkpointing: true + train_text_encoder: false + optimizer: "adamw" +# optimizer: "prodigy" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 4 + dtype: bf16 + xformers: true + min_snr_gamma: 5.0 +# skip_first_sample: true + noise_offset: 0.0 # not needed for this + model: + # objective reality v2 + name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of [trigger] laughing" + - "photo of [trigger] smiling" + 
- "[trigger] close up" + - "dark scene [trigger] frozen" + - "[trigger] nighttime" + - "a painting of [trigger]" + - "a drawing of [trigger]" + - "a cartoon of [trigger]" + - "[trigger] pixar style" + - "[trigger] costume" + neg: "" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/ai-toolkit/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py b/ai-toolkit/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..857cfa7578f29ff1343f0650209f81b6720d4ce0 --- /dev/null +++ b/ai-toolkit/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py @@ -0,0 +1,533 @@ +import copy +import random +from collections import OrderedDict +import os +from contextlib import nullcontext +from typing import Optional, Union, List +from torch.utils.data import ConcatDataset, DataLoader + +from toolkit.config_modules import ReferenceDatasetConfig +from toolkit.data_loader import PairedImageDataset +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds, build_latent_image_batch_for_prompt_pair +from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +from toolkit import train_tools +import torch +from jobs.process 
class UltimateSliderConfig(SliderConfig):
    """SliderConfig extended with the options read by the ultimate slider trainer.

    Keyword Args:
        additional_losses: list of loss names (default ``[]``).
        weight_jitter: jitter amplitude for network weights (default ``0.0``).
        img_loss_weight: image loss weight (default ``1.0``).
        cfg_loss_weight: CFG loss weight (default ``1.0``).
        datasets: list of dicts, each expanded into a ``ReferenceDatasetConfig``.
            A missing key or an explicit ``None`` both yield an empty list.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.additional_losses: List[str] = kwargs.get('additional_losses', [])
        self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
        self.img_loss_weight: float = kwargs.get('img_loss_weight', 1.0)
        self.cfg_loss_weight: float = kwargs.get('cfg_loss_weight', 1.0)
        # dict.get only falls back to its default when the key is absent; a key
        # that is present with value None (e.g. an empty `datasets:` entry)
        # would otherwise reach the comprehension and raise TypeError.
        dataset_entries = kwargs.get('datasets') or []
        self.datasets: List[ReferenceDatasetConfig] = [
            ReferenceDatasetConfig(**entry) for entry in dataset_entries
        ]
len(self.slider_config.datasets) > 0 + + def load_datasets(self): + if self.data_loader is None and \ + self.slider_config.datasets is not None and len(self.slider_config.datasets) > 0: + print(f"Loading datasets") + datasets = [] + for dataset in self.slider_config.datasets: + print(f" - Dataset: {dataset.pair_folder}") + config = { + 'path': dataset.pair_folder, + 'size': dataset.size, + 'default_prompt': dataset.target_class, + 'network_weight': dataset.network_weight, + 'pos_weight': dataset.pos_weight, + 'neg_weight': dataset.neg_weight, + 'pos_folder': dataset.pos_folder, + 'neg_folder': dataset.neg_folder, + } + image_dataset = PairedImageDataset(config) + datasets.append(image_dataset) + + # capture all the prompts from it so we can cache the embeds + self.dataset_prompts += image_dataset.get_all_prompts() + + concatenated_dataset = ConcatDataset(datasets) + self.data_loader = DataLoader( + concatenated_dataset, + batch_size=self.train_config.batch_size, + shuffle=True, + num_workers=2 + ) + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + # load any datasets if they were passed + self.load_datasets() + + # read line by line from file + if self.slider_config.prompt_file: + self.print(f"Loading prompt file from {self.slider_config.prompt_file}") + with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f: + self.prompt_txt_list = f.readlines() + # clean empty lines + self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] + + self.print(f"Found {len(self.prompt_txt_list)} prompts.") + + if not self.slider_config.prompt_tensors: + print(f"Prompt tensors not found. 
Building prompt tensors for {self.train_config.steps} steps.") + # shuffle + random.shuffle(self.prompt_txt_list) + # trim to max steps + self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps] + # trim list to our max steps + + cache = PromptEmbedsCache() + + # get encoded latents for our prompts + with torch.no_grad(): + # list of neutrals. Can come from file or be empty + neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""] + + # build the prompts to cache + prompts_to_cache = [] + for neutral in neutral_list: + for target in self.slider_config.targets: + prompt_list = [ + f"{target.target_class}", # target_class + f"{target.target_class} {neutral}", # target_class with neutral + f"{target.positive}", # positive_target + f"{target.positive} {neutral}", # positive_target with neutral + f"{target.negative}", # negative_target + f"{target.negative} {neutral}", # negative_target with neutral + f"{neutral}", # neutral + f"{target.positive} {target.negative}", # both targets + f"{target.negative} {target.positive}", # both targets reverse + ] + prompts_to_cache += prompt_list + + # remove duplicates + prompts_to_cache = list(dict.fromkeys(prompts_to_cache)) + + # trim to max steps if max steps is lower than prompt count + prompts_to_cache = prompts_to_cache[:self.train_config.steps] + + if len(self.dataset_prompts) > 0: + # add the prompts from the dataset + prompts_to_cache += self.dataset_prompts + + # encode them + cache = encode_prompts_to_cache( + prompt_list=prompts_to_cache, + sd=self.sd, + cache=cache, + prompt_tensor_file=self.slider_config.prompt_tensors + ) + + prompt_pairs = [] + prompt_batches = [] + for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False): + for target in self.slider_config.targets: + prompt_pair_batch = build_prompt_pair_batch_from_cache( + cache=cache, + target=target, + neutral=neutral, + + ) + if self.slider_config.batch_full_slide: + # concat the prompt pairs + # this 
allows us to run the entire 4 part process in one shot (for slider) + self.prompt_chunk_size = 4 + concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu') + prompt_pairs += [concat_prompt_pair_batch] + else: + self.prompt_chunk_size = 1 + # do them one at a time (probably not necessary after new optimizations) + prompt_pairs += [x.to('cpu') for x in prompt_pair_batch] + + # move to cpu to save vram + # We don't need text encoder anymore, but keep it on cpu for sampling + # if text encoder is list + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to("cpu") + else: + self.sd.text_encoder.to("cpu") + self.prompt_cache = cache + self.prompt_pairs = prompt_pairs + # end hook_before_train_loop + + # move vae to device so we can encode on the fly + # todo cache latents + self.sd.vae.to(self.device_torch) + self.sd.vae.eval() + self.sd.vae.requires_grad_(False) + + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() + + flush() + # end hook_before_train_loop + + def hook_train_loop(self, batch): + dtype = get_torch_dtype(self.train_config.dtype) + + with torch.no_grad(): + ### LOOP SETUP ### + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + ### TARGET_PROMPTS ### + # get a random pair + prompt_pair: EncodedPromptPair = self.prompt_pairs[ + torch.randint(0, len(self.prompt_pairs), (1,)).item() + ] + # move to device and dtype + prompt_pair.to(self.device_torch, dtype=dtype) + + ### PREP REFERENCE IMAGES ### + + imgs, prompts, network_weights = batch + network_pos_weight, network_neg_weight = network_weights + + if isinstance(network_pos_weight, torch.Tensor): + network_pos_weight = network_pos_weight.item() + if isinstance(network_neg_weight, torch.Tensor): + network_neg_weight = network_neg_weight.item() + + # get an array of random floats between -weight_jitter and weight_jitter 
+ weight_jitter = self.slider_config.weight_jitter + if weight_jitter > 0.0: + jitter_list = random.uniform(-weight_jitter, weight_jitter) + network_pos_weight += jitter_list + network_neg_weight += (jitter_list * -1.0) + + # if items in network_weight list are tensors, convert them to floats + imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype) + # split batched images in half so left is negative and right is positive + negative_images, positive_images = torch.chunk(imgs, 2, dim=3) + + height = positive_images.shape[2] + width = positive_images.shape[3] + batch_size = positive_images.shape[0] + + positive_latents = self.sd.encode_images(positive_images) + negative_latents = self.sd.encode_images(negative_images) + + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch) + current_timestep_index = timesteps.item() + current_timestep = noise_scheduler.timesteps[current_timestep_index] + timesteps = timesteps.long() + + # get noise + noise_positive = self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + noise_negative = noise_positive.clone() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps) + noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps) + + ### CFG SLIDER TRAINING PREP ### + + # get CFG txt latents + noisy_cfg_latents = build_latent_image_batch_for_prompt_pair( + pos_latent=noisy_positive_latents, + neg_latent=noisy_negative_latents, + prompt_pair=prompt_pair, + prompt_chunk_size=self.prompt_chunk_size, + ) + noisy_cfg_latents.requires_grad = False + + 
assert not self.network.is_active + + # 4.20 GB RAM for 512x512 + positive_latents = self.sd.predict_noise( + latents=noisy_cfg_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # negative prompt + prompt_pair.negative_target, # positive prompt + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + positive_latents.requires_grad = False + + neutral_latents = self.sd.predict_noise( + latents=noisy_cfg_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # negative prompt + prompt_pair.empty_prompt, # positive prompt (normally neutral + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + neutral_latents.requires_grad = False + + unconditional_latents = self.sd.predict_noise( + latents=noisy_cfg_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # negative prompt + prompt_pair.positive_target, # positive prompt + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + unconditional_latents.requires_grad = False + + positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0) + neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0) + unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0) + prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size) + noisy_cfg_latents_chunks = torch.chunk(noisy_cfg_latents, self.prompt_chunk_size, dim=0) + assert len(prompt_pair_chunks) == len(noisy_cfg_latents_chunks) + + noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0) + noise = torch.cat([noise_positive, noise_negative], dim=0) + timesteps = torch.cat([timesteps, timesteps], dim=0) + network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0] + + flush() + + loss_float = None + 
loss_mirror_float = None + + self.optimizer.zero_grad() + noisy_latents.requires_grad = False + + # TODO allow both processed to train text encoder, for now, we just to unet and cache all text encodes + # if training text encoder enable grads, else do context of no grad + # with torch.set_grad_enabled(self.train_config.train_text_encoder): + # # text encoding + # embedding_list = [] + # # embed the prompts + # for prompt in prompts: + # embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype) + # embedding_list.append(embedding) + # conditional_embeds = concat_prompt_embeds(embedding_list) + # conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + + if self.train_with_dataset: + embedding_list = [] + with torch.set_grad_enabled(self.train_config.train_text_encoder): + for prompt in prompts: + # get embedding form cache + embedding = self.prompt_cache[prompt] + embedding = embedding.to(self.device_torch, dtype=dtype) + embedding_list.append(embedding) + conditional_embeds = concat_prompt_embeds(embedding_list) + # double up so we can do both sides of the slider + conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + else: + # throw error. 
Not supported yet + raise Exception("Datasets and targets required for ultimate slider") + + if self.model_config.is_xl: + # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop + network_multiplier_list = network_multiplier + noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0) + noise_list = torch.chunk(noise, 2, dim=0) + timesteps_list = torch.chunk(timesteps, 2, dim=0) + conditional_embeds_list = split_prompt_embeds(conditional_embeds) + else: + network_multiplier_list = [network_multiplier] + noisy_latent_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditional_embeds_list = [conditional_embeds] + + ## DO REFERENCE IMAGE TRAINING ## + + reference_image_losses = [] + # allow to chunk it out to save vram + for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip( + network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list + ): + with self.network: + assert self.network.is_active + + self.network.multiplier = network_multiplier + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + ) + noise = noise.to(self.device_torch, dtype=dtype) + + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + # todo add snr gamma here + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + loss = loss * self.slider_config.img_loss_weight + loss_slide_float = loss.item() + + 
loss_float = loss.item() + reference_image_losses.append(loss_float) + + # back propagate loss to free ram + loss.backward() + flush() + + ## DO CFG SLIDER TRAINING ## + + cfg_loss_list = [] + + with self.network: + assert self.network.is_active + for prompt_pair_chunk, \ + noisy_cfg_latent_chunk, \ + positive_latents_chunk, \ + neutral_latents_chunk, \ + unconditional_latents_chunk \ + in zip( + prompt_pair_chunks, + noisy_cfg_latents_chunks, + positive_latents_chunks, + neutral_latents_chunks, + unconditional_latents_chunks, + ): + self.network.multiplier = prompt_pair_chunk.multiplier_list + + target_latents = self.sd.predict_noise( + latents=noisy_cfg_latent_chunk, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair_chunk.positive_target, # negative prompt + prompt_pair_chunk.target_class, # positive prompt + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + + guidance_scale = 1.0 + + offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk) + + # make offset multiplier based on actions + offset_multiplier_list = [] + for action in prompt_pair_chunk.action_list: + if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE: + offset_multiplier_list += [-1.0] + elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE: + offset_multiplier_list += [1.0] + + offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype) + # make offset multiplier match rank of offset + offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 1) + offset *= offset_multiplier + + offset_neutral = neutral_latents_chunk + # offsets are already adjusted on a per-batch basis + offset_neutral += offset + + # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing + loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and 
self.train_config.min_snr_gamma > 0.000001: + # match batch size + timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])] + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, + self.train_config.min_snr_gamma) + + loss = loss.mean() * prompt_pair_chunk.weight * self.slider_config.cfg_loss_weight + + loss.backward() + cfg_loss_list.append(loss.item()) + del target_latents + del offset_neutral + del loss + flush() + + # apply gradients + optimizer.step() + lr_scheduler.step() + + # reset network + self.network.multiplier = 1.0 + + reference_image_loss = sum(reference_image_losses) / len(reference_image_losses) if len( + reference_image_losses) > 0 else 0.0 + cfg_loss = sum(cfg_loss_list) / len(cfg_loss_list) if len(cfg_loss_list) > 0 else 0.0 + + loss_dict = OrderedDict({ + 'loss/img': reference_image_loss, + 'loss/cfg': cfg_loss, + }) + + return loss_dict + # end hook_train_loop diff --git a/ai-toolkit/extensions_built_in/ultimate_slider_trainer/__init__.py b/ai-toolkit/extensions_built_in/ultimate_slider_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7006db52a5882713dd071b683cb5892c2a0d00 --- /dev/null +++ b/ai-toolkit/extensions_built_in/ultimate_slider_trainer/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. 
from toolkit.extension import Extension


# We make a subclass of Extension
class UltimateSliderTrainer(Extension):
    """Extension descriptor for the "Ultimate Slider Trainer" process.

    The class only carries identification metadata (``uid`` and ``name``) and a
    lazy loader for the process class; it holds no training logic itself.

    NOTE(review): the example config shipped next to this module uses
    ``type: 'image_reference_slider_trainer'`` rather than this ``uid`` --
    confirm which identifier the job loader is expected to match.
    """

    # uid must be unique, it is how the extension is identified
    uid = "ultimate_slider_trainer"

    # name is the name of the extension for printing
    name = "Ultimate Slider Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        """Import and return the ``UltimateSliderTrainerProcess`` class.

        The import is done inside the method so the process module is only
        loaded when this extension is actually requested.
        """
        # import your process class here so it is only loaded when needed and return it
        from .UltimateSliderTrainerProcess import UltimateSliderTrainerProcess
        return UltimateSliderTrainerProcess


# Module-level list of the extensions this package exposes.
# presumably discovered by name by the toolkit's extension loader -- TODO confirm
AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    UltimateSliderTrainer
]
(most v2 models) + save: + dtype: float16 # precision to save + save_every: 1000 # save every this many steps + max_step_saves_to_keep: 2 # only affects step counts + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of a woman with red hair taking a selfie --m -3" + - "photo of a woman with red hair taking a selfie --m -1" + - "photo of a woman with red hair taking a selfie --m 1" + - "photo of a woman with red hair taking a selfie --m 3" + - "close up photo of a man smiling at the camera, in a tank top --m -3" + - "close up photo of a man smiling at the camera, in a tank top--m -1" + - "close up photo of a man smiling at the camera, in a tank top --m 1" + - "close up photo of a man smiling at the camera, in a tank top --m 3" + - "photo of a blonde woman smiling, barista --m -3" + - "photo of a blonde woman smiling, barista --m -1" + - "photo of a blonde woman smiling, barista --m 1" + - "photo of a blonde woman smiling, barista --m 3" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m 1" + - "photo of a Christina Hendricks --m 3" + - "photo of a Christina Ricci --m -3" + - "photo of a Christina Ricci --m -1" + - "photo of a Christina Ricci --m 1" + - "photo of a Christina Ricci --m 3" + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + + slider: + datasets: + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 2.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 4.0 + target_class: "" # only used as default if caption txt are not 
def load_captioning(uploaded_files, concept_sentence):
    """Build the list of Gradio updates that reveals the captioning UI.

    Args:
        uploaded_files: File paths coming from the upload component. Entries
            ending in ``.txt`` are treated as caption files, everything else
            as an image. ``None`` is treated as "nothing uploaded".
        concept_sentence: Trigger word/sentence typed by the user; used for the
            default caption and for the sample-prompt placeholders.

    Returns:
        A list of ``gr.update`` objects, in this order: one for the captioning
        area, then for each of the ``MAX_IMAGES`` slots a (row, image, caption)
        triple, then the sample accordion, the three sample prompts and the
        start button.

    Raises:
        gr.Error: If fewer than 2 or more than ``MAX_IMAGES`` images are given.
    """
    # The file component can deliver None when it has been emptied; treat that
    # as an empty upload so the user gets the friendly error below instead of
    # a TypeError.
    uploaded_files = uploaded_files or []

    uploaded_images = [path for path in uploaded_files if not path.endswith('.txt')]
    # Map "file name without extension" -> caption file path so captions can be
    # matched to images that share the same base name.
    caption_files = {
        os.path.splitext(os.path.basename(path))[0]: path
        for path in uploaded_files
        if path.endswith('.txt')
    }

    if len(uploaded_images) <= 1:
        raise gr.Error(
            "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
        )
    elif len(uploaded_images) > MAX_IMAGES:
        raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")

    updates = []
    # Update for the captioning_area
    updates.append(gr.update(visible=True))

    # Update visibility and image for each captioning row and image
    for slot in range(1, MAX_IMAGES + 1):
        # A slot is shown only if there is an uploaded image to put in it
        visible = slot <= len(uploaded_images)

        # Update visibility of the captioning row
        updates.append(gr.update(visible=visible))

        # Update for image component - display image if available, otherwise hide
        image_value = uploaded_images[slot - 1] if visible else None
        updates.append(gr.update(value=image_value, visible=visible))

        # Pre-fill the caption from a same-named .txt file when one was uploaded
        caption = None
        if image_value:
            base_name = os.path.splitext(os.path.basename(image_value))[0]
            if base_name in caption_files:
                with open(caption_files[base_name], 'r', encoding='utf-8') as caption_file:
                    caption = caption_file.read()

        # Caption priority: uploaded caption text, then the "[trigger]"
        # placeholder when a concept sentence exists, otherwise empty.
        if visible and caption:
            text_value = caption
        elif visible and concept_sentence:
            text_value = "[trigger]"
        else:
            text_value = None
        updates.append(gr.update(value=text_value, visible=visible))

    # Update for the sample caption area
    updates.append(gr.update(visible=True))
    # Update prompt samples
    updates.append(gr.update(placeholder=f'A portrait of person in a bustling cafe {concept_sentence}', value=f'A person in a bustling cafe {concept_sentence}'))
    updates.append(gr.update(placeholder=f"A mountainous landscape in the style of {concept_sentence}"))
    updates.append(gr.update(placeholder=f"A {concept_sentence} in a mall"))
    # Reveal the start button
    updates.append(gr.update(visible=True))
    return updates
def recursive_update(d, u):
    """Recursively merge mapping ``u`` into mapping ``d`` in place.

    For every key in ``u``:
      * a non-empty ``dict`` value is merged into the existing value at that
        key; if the existing value is missing or is not a ``dict`` (for example
        a scalar or ``None``), it is replaced by a fresh dict before merging,
      * any other value (including an empty ``dict``) overwrites the key.

    Args:
        d: Destination mapping; mutated in place.
        u: Mapping of overrides.

    Returns:
        ``d`` itself, after the merge.
    """
    for key, value in u.items():
        if isinstance(value, dict) and value:
            existing = d.get(key)
            # Overriding a scalar/None with a mapping must replace it; merging
            # "into" a non-dict would raise.
            if not isinstance(existing, dict):
                existing = {}
            d[key] = recursive_update(existing, value)
        else:
            d[key] = value
    return d
def start_training(
    lora_name,
    concept_sentence,
    steps,
    lr,
    rank,
    model_to_train,
    low_vram,
    dataset_folder,
    sample_1,
    sample_2,
    sample_3,
    use_more_advanced_options,
    more_advanced_options,
):
    """Build a training config from the UI inputs and run the job locally.

    The default config ``config/examples/train_lora_flux_24gb.yaml`` is loaded,
    patched with the user's choices, written to ``tmp/`` under a random name
    and then executed through ``get_job``.

    Args:
        lora_name: Display name of the LoRA; slugified for the config name.
        concept_sentence: Optional trigger word/sentence.
        steps: Number of training steps (coerced to ``int``).
        lr: Learning rate (coerced to ``float``).
        rank: LoRA rank, used for both ``linear`` and ``linear_alpha``.
        model_to_train: ``"schnell"`` switches the base model and adapter;
            any other value keeps the config defaults.
        low_vram: Value written to ``model.low_vram``.
        dataset_folder: Folder holding the prepared dataset.
        sample_1, sample_2, sample_3: Optional sample prompts; sampling is
            disabled when all are empty.
        use_more_advanced_options: Whether to merge ``more_advanced_options``.
        more_advanced_options: YAML text merged over the process config.

    Returns:
        A human readable completion message.

    Raises:
        gr.Error: If ``lora_name`` is empty or the advanced options are not a
            YAML mapping.
    """
    if not lora_name:
        raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")

    local_only_warning = "Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face"
    push_to_hub = True
    username = None
    try:
        # Query the Hub once and reuse the answer; each whoami() is a network call.
        user_info = whoami()
        access_token = user_info["auth"]["accessToken"]
        if access_token["role"] == "write" or "repo.write" in access_token["fineGrained"]["scoped"][0]["permissions"]:
            username = user_info["name"]
            gr.Info(f"Starting training locally {username}. Your LoRA will be available locally and in Hugging Face after it finishes.")
        else:
            push_to_hub = False
            gr.Warning(local_only_warning)
    except Exception:
        # Not logged in, or the token payload does not have the expected shape:
        # fall back to local-only training instead of failing.
        push_to_hub = False
        gr.Warning(local_only_warning)

    print("Started training")
    slugged_lora_name = slugify(lora_name)

    # Load the default config
    with open("config/examples/train_lora_flux_24gb.yaml", "r") as f:
        config = yaml.safe_load(f)

    # Update the config with user inputs
    config["config"]["name"] = slugged_lora_name
    process_config = config["config"]["process"][0]
    process_config["model"]["low_vram"] = low_vram
    process_config["train"]["skip_first_sample"] = True
    process_config["train"]["steps"] = int(steps)
    process_config["train"]["lr"] = float(lr)
    process_config["network"]["linear"] = int(rank)
    process_config["network"]["linear_alpha"] = int(rank)
    process_config["datasets"][0]["folder_path"] = dataset_folder
    process_config["save"]["push_to_hub"] = push_to_hub
    if push_to_hub:
        # username was captured from the same whoami() response that granted
        # push access, so it is always set here.
        process_config["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
        process_config["save"]["hf_private"] = True
    if concept_sentence:
        process_config["trigger_word"] = concept_sentence

    sample_prompts = [prompt for prompt in (sample_1, sample_2, sample_3) if prompt]
    if sample_prompts:
        process_config["train"]["disable_sampling"] = False
        # Sample once, at the end of training; keep the same integer type as
        # train.steps so the dumped YAML is consistent.
        process_config["sample"]["sample_every"] = int(steps)
        process_config["sample"]["sample_steps"] = 28
        process_config["sample"]["prompts"] = sample_prompts
    else:
        process_config["train"]["disable_sampling"] = True

    if model_to_train == "schnell":
        process_config["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
        process_config["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
        process_config["sample"]["sample_steps"] = 4

    if use_more_advanced_options:
        # safe_load returns None for empty / comment-only text.
        more_advanced_options_dict = yaml.safe_load(more_advanced_options) or {}
        if not isinstance(more_advanced_options_dict, dict):
            raise gr.Error("The advanced options must be a YAML mapping of settings.")
        config["config"]["process"][0] = recursive_update(process_config, more_advanced_options_dict)
        print(config)

    # Save the updated config
    # generate a random name for the config
    random_config_name = str(uuid.uuid4())
    os.makedirs("tmp", exist_ok=True)
    config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
    with open(config_path, "w") as f:
        yaml.dump(config, f)

    # run the job locally; always give the job a chance to release resources,
    # even when training fails part-way through
    job = get_job(config_path)
    try:
        job.run()
    finally:
        job.cleanup()

    return f"Training completed successfully. Model saved as {slugged_lora_name}"
interactive=True, + ) + with gr.Group(visible=True) as image_upload: + with gr.Row(): + images = gr.File( + file_types=["image", ".txt"], + label="Upload your images", + file_count="multiple", + interactive=True, + visible=True, + scale=1, + ) + with gr.Column(scale=3, visible=False) as captioning_area: + with gr.Column(): + gr.Markdown( + """# Custom captioning +

You can optionally add a custom caption for each image (or use an AI model for this). [trigger] will represent your concept sentence/trigger word.

+""", elem_classes="group_padding") + do_captioning = gr.Button("Add AI captions with Florence-2") + output_components = [captioning_area] + caption_list = [] + for i in range(1, MAX_IMAGES + 1): + locals()[f"captioning_row_{i}"] = gr.Row(visible=False) + with locals()[f"captioning_row_{i}"]: + locals()[f"image_{i}"] = gr.Image( + type="filepath", + width=111, + height=111, + min_width=111, + interactive=False, + scale=2, + show_label=False, + show_share_button=False, + show_download_button=False, + ) + locals()[f"caption_{i}"] = gr.Textbox( + label=f"Caption {i}", scale=15, interactive=True + ) + + output_components.append(locals()[f"captioning_row_{i}"]) + output_components.append(locals()[f"image_{i}"]) + output_components.append(locals()[f"caption_{i}"]) + caption_list.append(locals()[f"caption_{i}"]) + + with gr.Accordion("Advanced options", open=False): + steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1) + lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6) + rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4) + model_to_train = gr.Radio(["dev", "schnell"], value="dev", label="Model to train") + low_vram = gr.Checkbox(label="Low VRAM", value=True) + with gr.Accordion("Even more advanced options", open=False): + use_more_advanced_options = gr.Checkbox(label="Use more advanced options", value=False) + more_advanced_options = gr.Code(config_yaml, language="yaml") + + with gr.Accordion("Sample prompts (optional)", visible=False) as sample: + gr.Markdown( + "Include sample prompts to test out your trained model. 
Don't forget to include your trigger word/sentence (optional)" + ) + sample_1 = gr.Textbox(label="Test prompt 1") + sample_2 = gr.Textbox(label="Test prompt 2") + sample_3 = gr.Textbox(label="Test prompt 3") + + output_components.append(sample) + output_components.append(sample_1) + output_components.append(sample_2) + output_components.append(sample_3) + start = gr.Button("Start training", visible=False) + output_components.append(start) + progress_area = gr.Markdown("") + + dataset_folder = gr.State() + + images.upload( + load_captioning, + inputs=[images, concept_sentence], + outputs=output_components + ) + + images.delete( + load_captioning, + inputs=[images, concept_sentence], + outputs=output_components + ) + + images.clear( + hide_captioning, + outputs=[captioning_area, sample, start] + ) + + start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then( + fn=start_training, + inputs=[ + lora_name, + concept_sentence, + steps, + lr, + rank, + model_to_train, + low_vram, + dataset_folder, + sample_1, + sample_2, + sample_3, + use_more_advanced_options, + more_advanced_options + ], + outputs=progress_area, + ) + + do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list) + +if __name__ == "__main__": + demo.launch(share=True, show_error=True) \ No newline at end of file diff --git a/ai-toolkit/info.py b/ai-toolkit/info.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb2a82ef027b50275bc70dbe4cfb86f644d65e7 --- /dev/null +++ b/ai-toolkit/info.py @@ -0,0 +1,9 @@ +from collections import OrderedDict +from version import VERSION + +v = OrderedDict() +v["name"] = "ai-toolkit" +v["repo"] = "https://github.com/ostris/ai-toolkit" +v["version"] = VERSION + +software_meta = v diff --git a/ai-toolkit/jobs/BaseJob.py b/ai-toolkit/jobs/BaseJob.py new file mode 100644 index 0000000000000000000000000000000000000000..5b339ebe6d942f73f55e14f3de265e7e1db833b7 --- /dev/null 
class BaseJob:
    """Base class for all jobs described by a config file.

    A job wraps the parsed config, exposes convenient access to the
    ``config`` section and can instantiate the processes listed under
    ``config.process``.
    """

    def __init__(self, config: OrderedDict):
        """Store the parsed config and read the mandatory job fields.

        Args:
            config: Parsed config mapping. Must contain ``config`` (with at
                least ``name``) and ``job``; ``meta`` is optional.

        Raises:
            ValueError: If ``config`` is empty or ``config.name`` is missing.
        """
        if not config:
            raise ValueError('config is required')
        # populated by load_processes()
        self.process: List[BaseProcess]

        self.config = config['config']
        self.raw_config = config
        self.job = config['job']
        self.name = self.get_conf('name', required=True)
        # meta is optional; default to an empty mapping
        self.meta = config['meta'] if 'meta' in config else OrderedDict()

    def get_conf(self, key, default=None, required=False):
        """Return ``config[key]``.

        Args:
            key: Key to look up in the ``config`` section.
            default: Value returned when the key is absent and not required.
            required: When true, a missing key raises instead.

        Raises:
            ValueError: If ``required`` is true and the key is missing.
        """
        if key in self.config:
            return self.config[key]
        if required:
            raise ValueError(f'config file error. Missing "config.{key}" key')
        return default

    def run(self):
        """Print the job banner.

        Implement the actual work in a child class and call
        ``super().run()`` first.
        """
        print("")
        print("#############################################")
        print(f"# Running job: {self.name}")
        print("#############################################")
        print("")

    def load_processes(self, process_dict: dict):
        """Instantiate every process listed under ``config.process``.

        Only call this if the job type has processes.

        Args:
            process_dict: Maps a process ``type`` to either a class, or the
                name (``str``) of a class exported by ``jobs.process``.

        Raises:
            ValueError: If ``config.process`` is missing, empty, an entry has
                no ``type``, or the type is not in ``process_dict``.
        """
        if 'process' not in self.config:
            raise ValueError('config file is invalid. Missing "config.process" key')
        process_configs = self.config['process']
        # A bare `process:` key parses to None; report it like an empty list
        # instead of failing with a TypeError on len(None).
        if not process_configs:
            raise ValueError('config file is invalid. "config.process" must be a list of processes')

        module = importlib.import_module('jobs.process')

        # add the processes
        self.process = []
        for i, process in enumerate(process_configs):
            if 'type' not in process:
                raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')

            process_type = process['type']
            if process_type not in process_dict:
                raise ValueError(f'config file is invalid. Unknown process type: {process_type}')

            entry = process_dict[process_type]
            # a string names a class in jobs.process; anything else is the class itself
            ProcessClass = getattr(module, entry) if isinstance(entry, str) else entry
            self.process.append(ProcessClass(i, self, process))

    def cleanup(self):
        """Release job resources.

        The base implementation has nothing to release. If you implement this
        in a child class, be sure to call ``super().cleanup()`` LAST.
        """
        # Intentionally empty: the previous `del self` only unbound the local
        # name and had no effect on the object.
ExtractJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.base_model_path = self.get_conf('base_model', required=True) + self.model_base = None + self.model_base_text_encoder = None + self.model_base_vae = None + self.model_base_unet = None + self.extract_model_path = self.get_conf('extract_model', required=True) + self.model_extract = None + self.model_extract_text_encoder = None + self.model_extract_vae = None + self.model_extract_unet = None + self.extract_unet = self.get_conf('extract_unet', True) + self.extract_text_encoder = self.get_conf('extract_text_encoder', True) + self.dtype = self.get_conf('dtype', 'fp16') + self.torch_dtype = get_torch_dtype(self.dtype) + self.output_folder = self.get_conf('output_folder', required=True) + self.is_v2 = self.get_conf('is_v2', False) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + # load models + print(f"Loading models for extraction") + print(f" - Loading base model: {self.base_model_path}") + # (text_model, vae, unet) + self.model_base = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) + self.model_base_text_encoder = self.model_base[0] + self.model_base_vae = self.model_base[1] + self.model_base_unet = self.model_base[2] + + print(f" - Loading extract model: {self.extract_model_path}") + self.model_extract = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path) + self.model_extract_text_encoder = self.model_extract[0] + self.model_extract_vae = self.model_extract[1] + self.model_extract_unet = self.model_extract[2] + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/ai-toolkit/jobs/GenerateJob.py b/ai-toolkit/jobs/GenerateJob.py new file mode 100644 index 
0000000000000000000000000000000000000000..bd57a6ac7a1a97d9e68e86131b9e61ac9922e6d0 --- /dev/null +++ b/ai-toolkit/jobs/GenerateJob.py @@ -0,0 +1,24 @@ +from jobs import BaseJob +from collections import OrderedDict + +process_dict = { + 'to_folder': 'GenerateProcess', +} + + +class GenerateJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/ai-toolkit/jobs/MergeJob.py b/ai-toolkit/jobs/MergeJob.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e3b87b9ff589438d06c56019446f06efb76cda --- /dev/null +++ b/ai-toolkit/jobs/MergeJob.py @@ -0,0 +1,29 @@ +from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint +from collections import OrderedDict +from jobs import BaseJob +from toolkit.train_tools import get_torch_dtype + +process_dict = { +} + + +class MergeJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.dtype = self.get_conf('dtype', 'fp16') + self.torch_dtype = get_torch_dtype(self.dtype) + self.is_v2 = self.get_conf('is_v2', False) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/ai-toolkit/jobs/ModJob.py b/ai-toolkit/jobs/ModJob.py new file mode 100644 index 0000000000000000000000000000000000000000..e37990de95a0d2ad78a94f9cdfd6dfbda0cdc529 --- /dev/null +++ b/ai-toolkit/jobs/ModJob.py @@ -0,0 +1,28 @@ +import os +from collections import OrderedDict +from 
jobs import BaseJob +from toolkit.metadata import get_meta_for_safetensors +from toolkit.train_tools import get_torch_dtype + +process_dict = { + 'rescale_lora': 'ModRescaleLoraProcess', +} + + +class ModJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/ai-toolkit/jobs/TrainJob.py b/ai-toolkit/jobs/TrainJob.py new file mode 100644 index 0000000000000000000000000000000000000000..b4982d26690a0c63d5dbdd9063614308ee94491f --- /dev/null +++ b/ai-toolkit/jobs/TrainJob.py @@ -0,0 +1,44 @@ +import json +import os + +from jobs import BaseJob +from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint +from collections import OrderedDict +from typing import List +from jobs.process import BaseExtractProcess, TrainFineTuneProcess +from datetime import datetime + + +process_dict = { + 'vae': 'TrainVAEProcess', + 'slider': 'TrainSliderProcess', + 'slider_old': 'TrainSliderProcessOld', + 'lora_hack': 'TrainLoRAHack', + 'rescale_sd': 'TrainSDRescaleProcess', + 'esrgan': 'TrainESRGANProcess', + 'reference': 'TrainReferenceProcess', +} + + +class TrainJob(BaseJob): + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.training_folder = self.get_conf('training_folder', required=True) + self.is_v2 = self.get_conf('is_v2', False) + self.device = self.get_conf('device', 'cpu') + # self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1) + # self.mixed_precision = self.get_conf('mixed_precision', False) # fp16 + self.log_dir = self.get_conf('log_dir', None) + + # loads the processes from the config + self.load_processes(process_dict) + + + def run(self): + 
super().run() + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/ai-toolkit/jobs/__init__.py b/ai-toolkit/jobs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7da6c22b1ddd0ea9248e5afbf9b2ba014c137c1a --- /dev/null +++ b/ai-toolkit/jobs/__init__.py @@ -0,0 +1,7 @@ +from .BaseJob import BaseJob +from .ExtractJob import ExtractJob +from .TrainJob import TrainJob +from .MergeJob import MergeJob +from .ModJob import ModJob +from .GenerateJob import GenerateJob +from .ExtensionJob import ExtensionJob diff --git a/ai-toolkit/jobs/process/BaseExtensionProcess.py b/ai-toolkit/jobs/process/BaseExtensionProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..b53dc1c498e64bb4adbc2b967b329fdc4a374925 --- /dev/null +++ b/ai-toolkit/jobs/process/BaseExtensionProcess.py @@ -0,0 +1,19 @@ +from collections import OrderedDict +from typing import ForwardRef +from jobs.process.BaseProcess import BaseProcess + + +class BaseExtensionProcess(BaseProcess): + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.process_id: int + self.config: OrderedDict + self.progress_bar: ForwardRef('tqdm') = None + + def run(self): + super().run() diff --git a/ai-toolkit/jobs/process/BaseExtractProcess.py b/ai-toolkit/jobs/process/BaseExtractProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..ac10da54d82f15c8264b2799b10b01bb5cf8dc66 --- /dev/null +++ b/ai-toolkit/jobs/process/BaseExtractProcess.py @@ -0,0 +1,86 @@ +import os +from collections import OrderedDict + +from safetensors.torch import save_file + +from jobs.process.BaseProcess import BaseProcess +from toolkit.metadata import get_meta_for_safetensors + +from typing import ForwardRef + +from toolkit.train_tools import get_torch_dtype + + +class BaseExtractProcess(BaseProcess): + + 
def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.config: OrderedDict + self.output_folder: str + self.output_filename: str + self.output_path: str + self.process_id = process_id + self.job = job + self.config = config + self.dtype = self.get_conf('dtype', self.job.dtype) + self.torch_dtype = get_torch_dtype(self.dtype) + self.extract_unet = self.get_conf('extract_unet', self.job.extract_unet) + self.extract_text_encoder = self.get_conf('extract_text_encoder', self.job.extract_text_encoder) + + def run(self): + # here instead of init because child init needs to go first + self.output_path = self.get_output_path() + # implement in child class + # be sure to call super().run() first + pass + + # you can override this in the child class if you want + # call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this + def get_output_path(self, prefix=None, suffix=None): + config_output_path = self.get_conf('output_path', None) + config_filename = self.get_conf('filename', None) + # replace [name] with name + + if config_output_path is not None: + config_output_path = config_output_path.replace('[name]', self.job.name) + return config_output_path + + if config_output_path is None and config_filename is not None: + # build the output path from the output folder and filename + return os.path.join(self.job.output_folder, config_filename) + + # build our own + + if suffix is None: + # we will just add process it to the end of the filename if there is more than one process + # and no other suffix was given + suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else '' + + if prefix is None: + prefix = '' + + output_filename = f"{prefix}{self.output_filename}{suffix}" + + return os.path.join(self.job.output_folder, output_filename) + + def save(self, state_dict): + # prepare meta + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + + # save + 
os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(self.torch_dtype) + state_dict[key] = v + + # having issues with meta + save_file(state_dict, self.output_path, save_meta) + + print(f"Saved to {self.output_path}") diff --git a/ai-toolkit/jobs/process/BaseMergeProcess.py b/ai-toolkit/jobs/process/BaseMergeProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..55dfec68ae62383afae539ff6cb51862033a7e10 --- /dev/null +++ b/ai-toolkit/jobs/process/BaseMergeProcess.py @@ -0,0 +1,46 @@ +import os +from collections import OrderedDict + +from safetensors.torch import save_file + +from jobs.process.BaseProcess import BaseProcess +from toolkit.metadata import get_meta_for_safetensors +from toolkit.train_tools import get_torch_dtype + + +class BaseMergeProcess(BaseProcess): + + def __init__( + self, + process_id: int, + job, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.process_id: int + self.config: OrderedDict + self.output_path = self.get_conf('output_path', required=True) + self.dtype = self.get_conf('dtype', self.job.dtype) + self.torch_dtype = get_torch_dtype(self.dtype) + + def run(self): + # implement in child class + # be sure to call super().run() first + pass + + def save(self, state_dict): + # prepare meta + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + + # save + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(self.torch_dtype) + state_dict[key] = v + + # having issues with meta + save_file(state_dict, self.output_path, save_meta) + + print(f"Saved to {self.output_path}") diff --git a/ai-toolkit/jobs/process/BaseProcess.py b/ai-toolkit/jobs/process/BaseProcess.py new file mode 100644 index 
0000000000000000000000000000000000000000..c58724c987abb4521efc2afa9b1a85740f7429b8 --- /dev/null +++ b/ai-toolkit/jobs/process/BaseProcess.py @@ -0,0 +1,64 @@ +import copy +import json +from collections import OrderedDict + +from toolkit.timer import Timer + + +class BaseProcess(object): + + def __init__( + self, + process_id: int, + job: 'BaseJob', + config: OrderedDict + ): + self.process_id = process_id + self.meta: OrderedDict + self.job = job + self.config = config + self.raw_process_config = config + self.name = self.get_conf('name', self.job.name) + self.meta = copy.deepcopy(self.job.meta) + self.timer: Timer = Timer(f'{self.name} Timer') + self.performance_log_every = self.get_conf('performance_log_every', 0) + + print(json.dumps(self.config, indent=4)) + + def on_error(self, e: Exception): + pass + + def get_conf(self, key, default=None, required=False, as_type=None): + # split key by '.' and recursively get the value + keys = key.split('.') + + # see if it exists in the config + value = self.config + for subkey in keys: + if subkey in value: + value = value[subkey] + else: + value = None + break + + if value is not None: + if as_type is not None: + value = as_type(value) + return value + elif required: + raise ValueError(f'config file error. 
Missing "config.process[{self.process_id}].{key}" key') + else: + if as_type is not None and default is not None: + return as_type(default) + return default + + def run(self): + # implement in child class + # be sure to call super().run() first incase something is added here + pass + + def add_meta(self, additional_meta: OrderedDict): + self.meta.update(additional_meta) + + +from jobs import BaseJob diff --git a/ai-toolkit/jobs/process/BaseSDTrainProcess.py b/ai-toolkit/jobs/process/BaseSDTrainProcess.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7945fe5e5d0138d06aa1e5fa38a7cdc49a30bb --- /dev/null +++ b/ai-toolkit/jobs/process/BaseSDTrainProcess.py @@ -0,0 +1,2789 @@ +import copy +import glob +import inspect +import json +import random +import shutil +from collections import OrderedDict +import os +import re +import traceback +from typing import Union, List, Optional + +import numpy as np +import yaml +from diffusers import T2IAdapter, ControlNetModel +from diffusers.training_utils import compute_density_for_timestep_sampling +from safetensors.torch import save_file, load_file +# from lycoris.config import PRESET +from torch.utils.data import DataLoader +import torch +import torch.backends.cuda +from huggingface_hub import HfApi, interpreter_login +from toolkit.memory_management import MemoryManager + +from toolkit.basic import value_map +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +from toolkit.ema import ExponentialMovingAverage +from toolkit.embedding import Embedding +from toolkit.image_utils import show_tensors, show_latents, reduce_contrast +from toolkit.ip_adapter import IPAdapter +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.lorm import convert_diffusers_unet_to_lorm, 
count_parameters, print_lorm_extract_details, \ + lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE +from toolkit.lycoris_special import LycorisSpecialNetwork +from toolkit.models.decorator import Decorator +from toolkit.network_mixins import Network +from toolkit.optimizer import get_optimizer +from toolkit.paths import CONFIG_ROOT +from toolkit.progress_bar import ToolkitProgressBar +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.sampler import get_sampler +from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \ + load_ip_adapter_model, load_custom_adapter_model + +from toolkit.scheduler import get_lr_scheduler +from toolkit.sd_device_states_presets import get_train_sd_device_state_preset +from toolkit.stable_diffusion_model import StableDiffusion + +from jobs.process import BaseTrainProcess +from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta, \ + parse_metadata_from_safetensors +from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma, apply_learnable_snr_gos, apply_snr_weight +import gc + +from tqdm import tqdm + +from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ + GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \ + DecoratorConfig +from toolkit.logging_aitk import create_logger +from diffusers import FluxTransformer2DModel +from toolkit.accelerator import get_accelerator, unwrap_model +from toolkit.print import print_acc +from accelerate import Accelerator +import transformers +import diffusers +import hashlib + +from toolkit.util.blended_blur_noise import get_blended_blur_noise +from toolkit.util.get_model import get_model_class +from toolkit.basic import flush + + +class BaseSDTrainProcess(BaseTrainProcess): + + def __init__(self, 
process_id: int, job, config: OrderedDict, custom_pipeline=None): + super().__init__(process_id, job, config) + self.accelerator: Accelerator = get_accelerator() + if self.accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_error() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + self.sd: StableDiffusion + self.embedding: Union[Embedding, None] = None + + self.custom_pipeline = custom_pipeline + self.step_num = 0 + self.start_step = 0 + self.epoch_num = 0 + self.last_save_step = 0 + # start at 1 so we can do a sample at the start + self.grad_accumulation_step = 1 + # if true, then we do not do an optimizer step. We are accumulating gradients + self.is_grad_accumulation_step = False + self.device = str(self.accelerator.device) + self.device_torch = self.accelerator.device + network_config = self.get_conf('network', None) + if network_config is not None: + self.network_config = NetworkConfig(**network_config) + else: + self.network_config = None + self.train_config = TrainConfig(**self.get_conf('train', {})) + model_config = self.get_conf('model', {}) + self.modules_being_trained: List[torch.nn.Module] = [] + + # update modelconfig dtype to match train + model_config['dtype'] = self.train_config.dtype + self.model_config = ModelConfig(**model_config) + + self.save_config = SaveConfig(**self.get_conf('save', {})) + self.sample_config = SampleConfig(**self.get_conf('sample', {})) + first_sample_config = self.get_conf('first_sample', None) + if first_sample_config is not None: + self.has_first_sample_requested = True + self.first_sample_config = SampleConfig(**first_sample_config) + else: + self.has_first_sample_requested = False + self.first_sample_config = self.sample_config + self.logging_config = LoggingConfig(**self.get_conf('logging', {})) + self.logger = create_logger(self.logging_config, config, self.save_root) + 
self.optimizer: torch.optim.Optimizer = None + self.lr_scheduler = None + self.data_loader: Union[DataLoader, None] = None + self.data_loader_reg: Union[DataLoader, None] = None + self.trigger_word = self.get_conf('trigger_word', None) + + self.guidance_config: Union[GuidanceConfig, None] = None + guidance_config_raw = self.get_conf('guidance', None) + if guidance_config_raw is not None: + self.guidance_config = GuidanceConfig(**guidance_config_raw) + + # store is all are cached. Allows us to not load vae if we don't need to + self.is_latents_cached = True + raw_datasets = self.get_conf('datasets', None) + if raw_datasets is not None and len(raw_datasets) > 0: + raw_datasets = preprocess_dataset_raw_config(raw_datasets) + self.datasets = None + self.datasets_reg = None + self.dataset_configs: List[DatasetConfig] = [] + self.params = [] + + # add dataset text embedding cache to their config + if self.train_config.cache_text_embeddings: + for raw_dataset in raw_datasets: + raw_dataset['cache_text_embeddings'] = True + + if raw_datasets is not None and len(raw_datasets) > 0: + for raw_dataset in raw_datasets: + dataset = DatasetConfig(**raw_dataset) + # handle trigger word per dataset + if dataset.trigger_word is None and self.trigger_word is not None: + dataset.trigger_word = self.trigger_word + is_caching = dataset.cache_latents or dataset.cache_latents_to_disk + if not is_caching: + self.is_latents_cached = False + if dataset.is_reg: + if self.datasets_reg is None: + self.datasets_reg = [] + self.datasets_reg.append(dataset) + else: + if self.datasets is None: + self.datasets = [] + self.datasets.append(dataset) + self.dataset_configs.append(dataset) + + self.is_caching_text_embeddings = any( + dataset.cache_text_embeddings for dataset in self.dataset_configs + ) + + self.embed_config = None + embedding_raw = self.get_conf('embedding', None) + if embedding_raw is not None: + self.embed_config = EmbeddingConfig(**embedding_raw) + + self.decorator_config: 
DecoratorConfig = None + decorator_raw = self.get_conf('decorator', None) + if decorator_raw is not None: + if not self.model_config.is_flux: + raise ValueError("Decorators are only supported for Flux models currently") + self.decorator_config = DecoratorConfig(**decorator_raw) + + # t2i adapter + self.adapter_config = None + adapter_raw = self.get_conf('adapter', None) + if adapter_raw is not None: + self.adapter_config = AdapterConfig(**adapter_raw) + # sdxl adapters end in _xl. Only full_adapter_xl for now + if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'): + self.adapter_config.adapter_type += '_xl' + + # to hold network if there is one + self.network: Union[Network, None] = None + self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel, None] = None + self.embedding: Union[Embedding, None] = None + self.decorator: Union[Decorator, None] = None + + is_training_adapter = self.adapter_config is not None and self.adapter_config.train + + self.do_lorm = self.get_conf('do_lorm', False) + self.lorm_extract_mode = self.get_conf('lorm_extract_mode', 'ratio') + self.lorm_extract_mode_param = self.get_conf('lorm_extract_mode_param', 0.25) + # 'ratio', 0.25) + + # get the device state preset based on what we are training + self.train_device_state_preset = get_train_sd_device_state_preset( + device=self.device_torch, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + cached_latents=self.is_latents_cached, + train_lora=self.network_config is not None, + train_adapter=is_training_adapter, + train_embedding=self.embed_config is not None, + train_decorator=self.decorator_config is not None, + train_refiner=self.train_config.train_refiner, + unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings, + require_grads=False # we ensure them later + ) + + self.get_params_device_state_preset = 
get_train_sd_device_state_preset( + device=self.device_torch, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + cached_latents=self.is_latents_cached, + train_lora=self.network_config is not None, + train_adapter=is_training_adapter, + train_embedding=self.embed_config is not None, + train_decorator=self.decorator_config is not None, + train_refiner=self.train_config.train_refiner, + unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings, + require_grads=True # We check for grads when getting params + ) + + # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc) + self.is_fine_tuning = True + if self.network_config is not None or is_training_adapter or self.embed_config is not None or self.decorator_config is not None: + self.is_fine_tuning = False + + self.named_lora = False + if self.embed_config is not None or is_training_adapter: + self.named_lora = True + self.snr_gos: Union[LearnableSNRGamma, None] = None + self.ema: ExponentialMovingAverage = None + + validate_configs(self.train_config, self.model_config, self.save_config, self.dataset_configs) + + do_profiler = self.get_conf('torch_profiler', False) + self.torch_profiler = None if not do_profiler else torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) + + self.current_boundary_index = 0 + self.steps_this_boundary = 0 + self.num_consecutive_oom = 0 + self.additional_logs = {} + + def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]): + # override in subclass + return generate_image_config_list + + def sample(self, step=None, is_first=False): + if not self.accelerator.is_main_process: + return + flush() + sample_folder = os.path.join(self.save_root, 'samples') + gen_img_config_list = [] + + sample_config = self.first_sample_config if is_first 
else self.sample_config + start_seed = sample_config.seed + current_seed = start_seed + + test_image_paths = [] + if self.adapter_config is not None and self.adapter_config.test_img_path is not None: + test_image_path_list = self.adapter_config.test_img_path + # divide up images so they are evenly distributed across prompts + for i in range(len(sample_config.prompts)): + test_image_paths.append(test_image_path_list[i % len(test_image_path_list)]) + + for i in range(len(sample_config.prompts)): + if sample_config.walk_seed: + current_seed = start_seed + i + + step_num = '' + if step is not None: + # zero-pad 9 digits + step_num = f"_{str(step).zfill(9)}" + + filename = f"[time]_{step_num}_[count].{self.sample_config.ext}" + + output_path = os.path.join(sample_folder, filename) + + prompt = sample_config.prompts[i] + + # add embedding if there is one + # note: diffusers will automatically expand the trigger to the number of added tokens + # ie test123 will become test123 test123_1 test123_2 etc. 
Do not add this yourself here + if self.embedding is not None: + prompt = self.embedding.inject_embedding_to_prompt( + prompt, expand_token=True, add_if_not_present=False + ) + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter): + prompt = self.adapter.inject_trigger_into_prompt( + prompt, expand_token=True, add_if_not_present=False + ) + if self.trigger_word is not None: + prompt = self.sd.inject_trigger_into_prompt( + prompt, self.trigger_word, add_if_not_present=False + ) + + extra_args = {} + if self.adapter_config is not None and self.adapter_config.test_img_path is not None: + extra_args['adapter_image_path'] = test_image_paths[i] + + sample_item = sample_config.samples[i] + if sample_item.seed is not None: + current_seed = sample_item.seed + + gen_img_config_list.append(GenerateImageConfig( + prompt=prompt, # it will autoparse the prompt + width=sample_item.width, + height=sample_item.height, + negative_prompt=sample_item.neg, + seed=current_seed, + guidance_scale=sample_item.guidance_scale, + guidance_rescale=sample_config.guidance_rescale, + num_inference_steps=sample_item.sample_steps, + network_multiplier=sample_item.network_multiplier, + output_path=output_path, + output_ext=sample_config.ext, + adapter_conditioning_scale=sample_config.adapter_conditioning_scale, + refiner_start_at=sample_config.refiner_start_at, + extra_values=sample_config.extra_values, + logger=self.logger, + num_frames=sample_item.num_frames, + fps=sample_item.fps, + ctrl_img=sample_item.ctrl_img, + ctrl_idx=sample_item.ctrl_idx, + ctrl_img_1=sample_item.ctrl_img_1, + ctrl_img_2=sample_item.ctrl_img_2, + ctrl_img_3=sample_item.ctrl_img_3, + do_cfg_norm=sample_config.do_cfg_norm, + **extra_args + )) + + # post process + gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list) + + # if we have an ema, set it to validation mode + if self.ema is not None: + self.ema.eval() + + # let adapter know we are sampling + if self.adapter is 
def clean_up_saves(self):
    """Prune old step checkpoints in ``self.save_root`` and return the newest one.

    Only the main process does any work; other processes return ``None``.

    Checkpoints are grouped into ``.safetensors`` files, ``.pt`` files,
    directories, embedding files (only when ``self.embed_config`` is set) and
    ``CRITIC_*`` items.  Each group is ordered by ``os.path.getctime`` and all
    but the last ``max_step_saves_to_keep`` entries of every group are
    deleted, together with a same-named ``.yaml`` sidecar when one exists.

    The retention count is ``self.save_config.max_step_saves_to_keep``,
    multiplied by ``self.sd.max_step_saves_to_keep_multiplier`` when the model
    defines it.  A count of ``None`` or a non-positive count disables pruning
    instead of raising / slicing from the wrong end, and a fractional product
    is truncated to an int so it can be used as a slice bound.

    Returns:
        The path of the most recent ``.safetensors`` / ``.pt`` / directory
        checkpoint of this job, or ``None`` when there is none (or when
        ``self.save_root`` does not exist).
    """
    if not self.accelerator.is_main_process:
        return
    latest_item = None
    if not os.path.exists(self.save_root):
        return latest_item

    def oldest_first(paths):
        # sorted() is stable, exactly like the list.sort() calls it replaces
        return sorted(paths, key=os.path.getctime)

    # pattern is {job_name}_{zero_filled_step} for both files and directories
    items = glob.glob(os.path.join(self.save_root, f"{self.job.name}_*"))
    safetensors_files = oldest_first([f for f in items if f.endswith('.safetensors')])
    pt_files = oldest_first([f for f in items if f.endswith('.pt')])
    directories = oldest_first(
        [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
    )

    # embedding files are named after the trigger and end in safetensors or pt
    embed_files = []
    if self.embed_config is not None:
        embed_items = glob.glob(os.path.join(self.save_root, f"{self.embed_config.trigger}_*"))
        embed_files = oldest_first(
            [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')]
        )

    # critic files
    critic_items = oldest_first(
        glob.glob(os.path.join(self.save_root, f"CRITIC_{self.job.name}_*"))
    )

    # everything that can be resumed from, used only to find the newest item
    combined_items = oldest_first(safetensors_files + directories + pt_files)

    num_saves_to_keep = self.save_config.max_step_saves_to_keep
    if num_saves_to_keep is not None:
        if hasattr(self.sd, 'max_step_saves_to_keep_multiplier'):
            num_saves_to_keep *= self.sd.max_step_saves_to_keep_multiplier
        # slice bounds must be integers
        num_saves_to_keep = int(num_saves_to_keep)

    items_to_remove = []
    if num_saves_to_keep is not None and num_saves_to_keep > 0:
        # every group keeps its own newest `num_saves_to_keep` entries
        for group in (safetensors_files, pt_files, directories, embed_files, critic_items):
            items_to_remove.extend(group[:-num_saves_to_keep])

    # remove duplicates while preserving order
    items_to_remove = list(dict.fromkeys(items_to_remove))

    for item in items_to_remove:
        print_acc(f"Removing old save: {item}")
        if os.path.isdir(item):
            shutil.rmtree(item)
        else:
            os.remove(item)
        # see if a yaml file with same name exists
        yaml_file = os.path.splitext(item)[0] + ".yaml"
        if os.path.exists(yaml_file):
            os.remove(yaml_file)

    if combined_items:
        latest_item = combined_items[-1]
    return latest_item
post_save_hook(self, save_path): + # override in subclass + pass + + def done_hook(self): + pass + + def end_step_hook(self): + pass + + def save(self, step=None): + if not self.accelerator.is_main_process: + return + flush() + if self.ema is not None: + # always save params as ema + self.ema.eval() + + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + step_num = '' + if step is not None: + self.last_save_step = step + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + + self.update_training_metadata() + filename = f'{self.job.name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + + save_meta = copy.deepcopy(self.meta) + # get extra meta + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + additional_save_meta = self.adapter.get_additional_save_metadata() + if additional_save_meta is not None: + for key, value in additional_save_meta.items(): + save_meta[key] = value + + # prepare meta + save_meta = get_meta_for_safetensors(save_meta, self.job.name) + if not self.is_fine_tuning and not self.train_config.merge_network_on_save: + if self.network is not None: + lora_name = self.job.name + if self.named_lora: + # add _lora to name + lora_name += '_LoRA' + + filename = f'{lora_name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + prev_multiplier = self.network.multiplier + self.network.multiplier = 1.0 + + # if we are doing embedding training as well, add that + embedding_dict = self.embedding.state_dict() if self.embedding else None + self.network.save_weights( + file_path, + dtype=get_torch_dtype(self.save_config.dtype), + metadata=save_meta, + extra_state_dict=embedding_dict + ) + self.network.multiplier = prev_multiplier + # if we have an embedding as well, pair it with the network + + # even if added to lora, still save the trigger version + if self.embedding is not None: + emb_filename = 
f'{self.embed_config.trigger}{step_num}.safetensors' + emb_file_path = os.path.join(self.save_root, emb_filename) + # for combo, above will get it + # set current step + self.embedding.step = self.step_num + # change filename to pt if that is set + if self.embed_config.save_format == "pt": + # replace extension + emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt" + self.embedding.save(emb_file_path) + + if self.decorator is not None: + dec_filename = f'{self.job.name}{step_num}.safetensors' + dec_file_path = os.path.join(self.save_root, dec_filename) + decorator_state_dict = self.decorator.state_dict() + for key, value in decorator_state_dict.items(): + if isinstance(value, torch.Tensor): + decorator_state_dict[key] = value.clone().to('cpu', dtype=get_torch_dtype(self.save_config.dtype)) + save_file( + decorator_state_dict, + dec_file_path, + metadata=save_meta, + ) + + if self.adapter is not None and self.adapter_config.train: + adapter_name = self.job.name + if self.network_config is not None or self.embedding is not None: + # add _lora to name + if self.adapter_config.type == 't2i': + adapter_name += '_t2i' + elif self.adapter_config.type == 'control_net': + adapter_name += '_cn' + elif self.adapter_config.type == 'clip': + adapter_name += '_clip' + elif self.adapter_config.type.startswith('ip'): + adapter_name += '_ip' + else: + adapter_name += '_adapter' + + filename = f'{adapter_name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + # save adapter + state_dict = self.adapter.state_dict() + if self.adapter_config.type == 't2i': + save_t2i_from_diffusers( + state_dict, + output_file=file_path, + meta=save_meta, + dtype=get_torch_dtype(self.save_config.dtype) + ) + elif self.adapter_config.type == 'control_net': + # save in diffusers format + name_or_path = file_path.replace('.safetensors', '') + # move it to the new dtype and cpu + orig_device = self.adapter.device + orig_dtype = self.adapter.dtype + self.adapter = 
self.adapter.to(torch.device('cpu'), dtype=get_torch_dtype(self.save_config.dtype)) + self.adapter.save_pretrained( + name_or_path, + dtype=get_torch_dtype(self.save_config.dtype), + safe_serialization=True + ) + meta_path = os.path.join(name_or_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(self.meta, f) + # move it back + self.adapter = self.adapter.to(orig_device, dtype=orig_dtype) + else: + direct_save = False + if self.adapter_config.train_only_image_encoder: + direct_save = True + elif isinstance(self.adapter, CustomAdapter): + direct_save = self.adapter.do_direct_save + save_ip_adapter_from_diffusers( + state_dict, + output_file=file_path, + meta=save_meta, + dtype=get_torch_dtype(self.save_config.dtype), + direct_save=direct_save + ) + else: + if self.network is not None and self.train_config.merge_network_on_save: + # merge the network weights into a full model and save that. + # torchao quantized weights can be force merged here (dequantize -> merge -> re-quantize) + # even though can_merge_in is False (kept False so sampling never merges). quanto and + # layer_offloading still cannot merge. + from toolkit.util.quantize import get_torchao_config + can_force_quantized_merge = ( + self.model_config.quantize and not self.model_config.layer_offloading + and get_torchao_config(self.model_config.qtype) is not None + ) + if not self.network.can_merge_in and not can_force_quantized_merge: + raise ValueError("Network cannot merge in weights. Cannot save full model.") + + print_acc("Merging network weights into full model for saving...") + + self.network.merge_in(merge_weight=self.train_config.merge_network_on_save_strength) + # reset weights to zero + self.network.reset_weights() + self.network.is_merged_in = False + + print_acc("Done merging network weights. 
Saving model...") + + if self.save_config.save_format == "diffusers": + # saving as a folder path + file_path = file_path.replace('.safetensors', '') + # convert it back to normal object + save_meta = parse_metadata_from_safetensors(save_meta) + + if self.sd.refiner_unet and self.train_config.train_refiner: + # save refiner + refiner_name = self.job.name + '_refiner' + filename = f'{refiner_name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + self.sd.save_refiner( + file_path, + save_meta, + get_torch_dtype(self.save_config.dtype) + ) + if self.train_config.train_unet or self.train_config.train_text_encoder: + self.sd.save( + file_path, + save_meta, + get_torch_dtype(self.save_config.dtype) + ) + + # save learnable params as json if we have thim + if self.snr_gos: + json_data = { + 'offset_1': self.snr_gos.offset_1.item(), + 'offset_2': self.snr_gos.offset_2.item(), + 'scale': self.snr_gos.scale.item(), + 'gamma': self.snr_gos.gamma.item(), + } + path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json') + with open(path_to_save, 'w') as f: + json.dump(json_data, f, indent=4) + + print_acc(f"Saved checkpoint to {file_path}") + + # save optimizer + if self.optimizer is not None: + try: + filename = f'optimizer.pt' + file_path = os.path.join(self.save_root, filename) + try: + state_dict = unwrap_model(self.optimizer).state_dict() + except Exception as e: + state_dict = self.optimizer.state_dict() + torch.save(state_dict, file_path) + print_acc(f"Saved optimizer to {file_path}") + except Exception as e: + print_acc(e) + print_acc("Could not save optimizer") + + self.clean_up_saves() + self.post_save_hook(file_path) + + if self.ema is not None: + self.ema.train() + flush() + + # Called before the model is loaded + def hook_before_model_load(self): + # override in subclass + pass + + def hook_after_model_load(self): + # override in subclass + pass + + def hook_add_extra_train_params(self, params): + # override in 
subclass + return params + + def hook_before_train_loop(self): + if self.accelerator.is_main_process: + self.logger.start() + self.prepare_accelerator() + + def sample_step_hook(self, img_num, total_imgs): + pass + + def prepare_accelerator(self): + # set some config + self.accelerator.even_batches=False + + # # prepare all the models stuff for accelerator (hopefully we dont miss any) + self.sd.vae = self.accelerator.prepare(self.sd.vae) + if self.sd.unet is not None: + self.sd.unet = self.accelerator.prepare(self.sd.unet) + # todo always tdo it? + self.modules_being_trained.append(self.sd.unet) + if self.sd.text_encoder is not None and self.train_config.train_text_encoder: + if isinstance(self.sd.text_encoder, list): + self.sd.text_encoder = [self.accelerator.prepare(model) for model in self.sd.text_encoder] + self.modules_being_trained.extend(self.sd.text_encoder) + else: + self.sd.text_encoder = self.accelerator.prepare(self.sd.text_encoder) + self.modules_being_trained.append(self.sd.text_encoder) + if self.sd.refiner_unet is not None and self.train_config.train_refiner: + self.sd.refiner_unet = self.accelerator.prepare(self.sd.refiner_unet) + self.modules_being_trained.append(self.sd.refiner_unet) + # todo, do we need to do the network or will "unet" get it? + if self.sd.network is not None: + self.sd.network = self.accelerator.prepare(self.sd.network) + self.modules_being_trained.append(self.sd.network) + if self.adapter is not None and self.adapter_config.train: + # todo adapters may not be a module. 
def get_latest_save_path(self, name=None, post=''):
    """Return the most recently created checkpoint for ``name`` in ``self.save_root``.

    Args:
        name: Checkpoint name prefix. Defaults to ``self.job.name``.
        post: Optional suffix expected right before the extension
            (or at the end of a directory name).

    Returns:
        The matching path with the greatest ``os.path.getctime``, or, when no
        checkpoint is found, ``self.network_config.pretrained_lora_path`` if it
        is configured and exists on disk, otherwise ``None``.

    Candidates are ``{name}*{post}.safetensors``, ``{name}*{post}.pt`` and
    ``{name}*{post}``.  Entries whose *file name* carries a ``_LoRA``,
    ``_refiner``, ``_t2i`` or ``_cn`` marker that ``name`` itself does not
    contain are treated as false positives.  The marker test looks at the
    base name only, so a ``save_root`` that happens to contain one of the
    markers no longer hides every checkpoint.
    """
    if name is None:
        name = self.job.name
    # get latest saved step
    latest_path = None
    if os.path.exists(self.save_root):
        # escape glob metacharacters so names/roots like "run[1]" match literally
        root = glob.escape(self.save_root)
        stem = glob.escape(name)
        tail = glob.escape(post)
        # patterns for both files and directories
        patterns = [
            f"{stem}*{tail}.safetensors",
            f"{stem}*{tail}.pt",
            f"{stem}*{tail}",
        ]
        paths = []
        for pattern in patterns:
            paths.extend(glob.glob(os.path.join(root, pattern)))

        # the last pattern overlaps the first two: drop duplicates, keep order,
        # and ignore anything that vanished in the meantime
        paths = [p for p in dict.fromkeys(paths) if os.path.exists(p)]

        # remove false positives
        for marker in ('_LoRA', '_refiner', '_t2i', '_cn'):
            if marker not in name:
                paths = [p for p in paths if marker not in os.path.basename(p)]

        if len(paths) > 0:
            latest_path = max(paths, key=os.path.getctime)

    if latest_path is None and self.network_config is not None and self.network_config.pretrained_lora_path is not None:
        # set pretrained lora path as load path if we do not have a checkpoint to resume from
        if os.path.exists(self.network_config.pretrained_lora_path):
            latest_path = self.network_config.pretrained_lora_path
            print_acc(f"Using pretrained lora path from config: {latest_path}")
        else:
            # no pretrained lora found
            print_acc(f"Pretrained lora path from config does not exist: {self.network_config.pretrained_lora_path}")

    return latest_path
def apply_snr(self, seperated_loss, timesteps):
    """Weight ``seperated_loss`` according to the SNR option enabled in the train config.

    The options are checked in priority order and the first one that applies wins:

    1. ``learnable_snr_gos``: delegate to ``apply_learnable_snr_gos`` with ``self.snr_gos``.
    2. ``snr_gamma`` above ``0.000001``: ``apply_snr_weight`` with ``fixed=True``.
    3. ``min_snr_gamma`` above ``0.000001``: ``apply_snr_weight`` with its default mode.

    When none applies the loss is returned untouched.
    """
    config = self.train_config

    if config.learnable_snr_gos:
        # add snr_gamma
        return apply_learnable_snr_gos(seperated_loss, timesteps, self.snr_gos)

    if config.snr_gamma is not None and config.snr_gamma > 0.000001:
        # add snr_gamma
        return apply_snr_weight(
            seperated_loss,
            timesteps,
            self.sd.noise_scheduler,
            config.snr_gamma,
            fixed=True,
        )

    if config.min_snr_gamma is not None and config.min_snr_gamma > 0.000001:
        # add min_snr_gamma
        return apply_snr_weight(
            seperated_loss,
            timesteps,
            self.sd.noise_scheduler,
            config.min_snr_gamma,
        )

    return seperated_loss
def get_optimal_noise(self, latents, dtype=torch.float32):
    """Pick, per batch item, the random noise candidate closest to that item's latent.

    For every slice of size 1 along dim 0 of ``latents``,
    ``self.train_config.optimal_noise_pairing_samples`` gaussian candidates of
    the same shape are drawn and the one with the lowest MSE against the latent
    is kept (the first one wins on ties).

    Args:
        latents: Tensor whose first dimension is the batch.
        dtype: dtype of the returned noise; the latent is cast to it for the
            comparison so mixed-precision latents compare in a single dtype.

    Returns:
        A tensor with the shape and device of ``latents`` and dtype ``dtype``.

    An empty batch yields an empty noise tensor, and a configured sample count
    below 1 is treated as 1, so a candidate is always selected.
    """
    batch_num = latents.shape[0]
    if batch_num == 0:
        # torch.chunk rejects a chunk count of 0; nothing to pair anyway
        return torch.randn_like(latents, device=latents.device, dtype=dtype)

    num_candidates = max(int(self.train_config.optimal_noise_pairing_samples), 1)

    noise_chunks = []
    for chunk in torch.chunk(latents, batch_num, dim=0):
        candidates = [
            torch.randn_like(chunk, device=chunk.device, dtype=dtype)
            for _ in range(num_candidates)
        ]
        target = chunk.to(dtype)
        losses = torch.stack(
            [torch.nn.functional.mse_loss(target, candidate) for candidate in candidates]
        )
        # a single argmin replaces one tensor->bool comparison per candidate and
        # always yields a valid index; torch.argmin returns the first minimum
        noise_chunks.append(candidates[int(torch.argmin(losses))])

    return torch.cat(noise_chunks, dim=0)
img_path = file_item.path + # add augmentors + if file_item.flip_x: + img_path += '_fx' + if file_item.flip_y: + img_path += '_fy' + seed = int(hashlib.md5(img_path.encode()).hexdigest(), 16) & 0xffffffff + generator = torch.Generator("cpu").manual_seed(seed) + noise_chunk = torch.randn(chunk.shape, generator=generator).to(chunk.device, dtype=dtype) + noise_chunks.append(noise_chunk) + noise = torch.cat(noise_chunks, dim=0).to(dtype=dtype) + return noise + + + def get_noise( + self, + latents, + batch_size, + dtype=torch.float32, + batch: 'DataLoaderBatchDTO' = None, + timestep=None, + ): + if self.train_config.optimal_noise_pairing_samples > 1: + noise = self.get_optimal_noise(latents, dtype=dtype) + elif self.train_config.force_consistent_noise: + if batch is None: + raise ValueError("Batch must be provided for consistent noise") + noise = self.get_consistent_noise(latents, batch, dtype=dtype) + else: + if hasattr(self.sd, 'get_latent_noise_from_latents'): + noise = self.sd.get_latent_noise_from_latents( + latents, + noise_offset=self.train_config.noise_offset + ).to(self.device_torch, dtype=dtype) + else: + # get noise + noise = self.sd.get_latent_noise( + height=latents.shape[2], + width=latents.shape[3], + num_channels=latents.shape[1], + batch_size=batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + if self.train_config.blended_blur_noise: + noise = get_blended_blur_noise( + latents, noise, timestep + ) + + return noise + + def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): + with torch.no_grad(): + with self.timer('prepare_prompt'): + prompts = batch.get_caption_list() + is_reg_list = batch.get_is_reg_list() + + is_any_reg = any([is_reg for is_reg in is_reg_list]) + + do_double = self.train_config.short_and_long_captions and not is_any_reg + + if self.train_config.short_and_long_captions and do_double: + # dont do this with regs. 
No point + + # double batch and add short captions to the end + prompts = prompts + batch.get_caption_short_list() + is_reg_list = is_reg_list + is_reg_list + if self.model_config.refiner_name_or_path is not None and self.train_config.train_unet: + prompts = prompts + prompts + is_reg_list = is_reg_list + is_reg_list + + conditioned_prompts = [] + + for prompt, is_reg in zip(prompts, is_reg_list): + + # make sure the embedding is in the prompts + if self.embedding is not None: + prompt = self.embedding.inject_embedding_to_prompt( + prompt, + expand_token=True, + add_if_not_present=not is_reg, + ) + + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + prompt = self.adapter.inject_trigger_into_prompt( + prompt, + expand_token=True, + add_if_not_present=not is_reg, + ) + + # make sure trigger is in the prompts if not a regularization run + if self.trigger_word is not None: + prompt = self.sd.inject_trigger_into_prompt( + prompt, + trigger=self.trigger_word, + add_if_not_present=not is_reg, + ) + + if not is_reg and self.train_config.prompt_saturation_chance > 0.0: + # do random prompt saturation by expanding the prompt to hit at least 77 tokens + if random.random() < self.train_config.prompt_saturation_chance: + est_num_tokens = len(prompt.split(' ')) + if est_num_tokens < 77: + num_repeats = int(77 / est_num_tokens) + 1 + prompt = ', '.join([prompt] * num_repeats) + + + conditioned_prompts.append(prompt) + + with self.timer('prepare_latents'): + dtype = get_torch_dtype(self.train_config.dtype) + imgs = None + is_reg = any(batch.get_is_reg_list()) + if batch.tensor is not None: + imgs = batch.tensor + imgs = imgs.to(self.device_torch, dtype=dtype) + # dont adjust for regs. 
+ if self.train_config.img_multiplier is not None and not is_reg: + # do it ad contrast + imgs = reduce_contrast(imgs, self.train_config.img_multiplier) + if batch.latents is not None: + latents = batch.latents.to(self.device_torch, dtype=dtype) + batch.latents = latents + else: + # normalize to + if self.train_config.standardize_images: + if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd: + target_mean_list = [0.0002, -0.1034, -0.1879] + target_std_list = [0.5436, 0.5116, 0.5033] + else: + target_mean_list = [-0.0739, -0.1597, -0.2380] + target_std_list = [0.5623, 0.5295, 0.5347] + # Mean: tensor([-0.0739, -0.1597, -0.2380]) + # Standard Deviation: tensor([0.5623, 0.5295, 0.5347]) + imgs_channel_mean = imgs.mean(dim=(2, 3), keepdim=True) + imgs_channel_std = imgs.std(dim=(2, 3), keepdim=True) + imgs = (imgs - imgs_channel_mean) / imgs_channel_std + target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype) + target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype) + # expand them to match dim + target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3) + target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + imgs = imgs * target_std + target_mean + batch.tensor = imgs + + # show_tensors(imgs, 'imgs') + + latents = self.sd.encode_images(imgs) + batch.latents = latents + + if self.train_config.standardize_latents: + if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd: + target_mean_list = [-0.1075, 0.0231, -0.0135, 0.2164] + target_std_list = [0.8979, 0.7505, 0.9150, 0.7451] + else: + target_mean_list = [0.2949, -0.3188, 0.0807, 0.1929] + target_std_list = [0.8560, 0.9629, 0.7778, 0.6719] + + latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) + latents_channel_std = latents.std(dim=(2, 3), keepdim=True) + latents = (latents - latents_channel_mean) / latents_channel_std + target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype) + target_std = 
torch.tensor(target_std_list, device=self.device_torch, dtype=dtype) + # expand them to match dim + target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3) + target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + latents = latents * target_std + target_mean + batch.latents = latents + + # show_latents(latents, self.sd.vae, 'latents') + + + if batch.unconditional_tensor is not None and batch.unconditional_latents is None: + unconditional_imgs = batch.unconditional_tensor + unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype) + unconditional_latents = self.sd.encode_images(unconditional_imgs) + batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier + + unaugmented_latents = None + if self.train_config.loss_target == 'differential_noise': + # we determine noise from the differential of the latents + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor) + + with self.timer('prepare_scheduler'): + + batch_size = len(batch.file_items) + min_noise_steps = self.train_config.min_denoising_steps + max_noise_steps = self.train_config.max_denoising_steps + if self.model_config.refiner_name_or_path is not None: + # if we are not training the unet, then we are only doing refiner and do not need to double up + if self.train_config.train_unet: + max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) + do_double = True + else: + min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) + do_double = False + + num_train_timesteps = self.train_config.num_train_timesteps + + if self.train_config.noise_scheduler in ['custom_lcm']: + # we store this value on our custom one + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.train_timesteps, device=self.device_torch + ) + elif self.train_config.noise_scheduler in ['lcm']: + self.sd.noise_scheduler.set_timesteps( + num_train_timesteps, 
device=self.device_torch, original_inference_steps=num_train_timesteps + ) + elif self.train_config.noise_scheduler == 'flowmatch': + linear_timesteps = any([ + self.train_config.linear_timesteps, + self.train_config.linear_timesteps2, + self.train_config.timestep_type == 'linear', + self.train_config.timestep_type in ['one_step', 'two_step', 'four_step', 'eight_step'], + ]) + + timestep_type = 'linear' if linear_timesteps else None + if timestep_type is None: + timestep_type = self.train_config.timestep_type + + if self.train_config.timestep_type == 'next_sample': + # simulate a sample + num_train_timesteps = self.train_config.next_sample_timesteps + timestep_type = 'shift' + + patch_size = 1 + if self.sd.is_flux or 'flex' in self.sd.arch: + # flux is a patch size of 1, but latents are divided by 2, so we need to double it + patch_size = 2 + elif hasattr(self.sd.unet, 'config') and hasattr(self.sd.unet.config, 'patch_size'): + patch_size = self.sd.unet.config.patch_size + + self.sd.noise_scheduler.set_train_timesteps( + num_train_timesteps, + device=self.device_torch, + timestep_type=timestep_type, + latents=latents, + patch_size=patch_size, + ) + else: + self.sd.noise_scheduler.set_timesteps( + num_train_timesteps, device=self.device_torch + ) + if self.sd.is_multistage: + with self.timer('adjust_multistage_timesteps'): + # get our current sample range + boundaries = [1] + self.sd.multistage_boundaries + boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1] + asc_timesteps = torch.flip(self.sd.noise_scheduler.timesteps, dims=[0]) + lo = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_max * 1000, device=asc_timesteps.device), right=False) + hi = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_min * 1000, device=asc_timesteps.device), right=True) + first_idx = (lo - 1).item() if hi > lo else 0 + last_idx = (hi - 1).item() if hi > lo else 999 + 
min_noise_steps = first_idx + max_noise_steps = last_idx + + # clip min max indicies + min_noise_steps = max(min_noise_steps, 0) + max_noise_steps = min(max_noise_steps, num_train_timesteps - 1) + + + with self.timer('prepare_timesteps_indices'): + + content_or_style = self.train_config.content_or_style + if is_reg: + content_or_style = self.train_config.content_or_style_reg + + if self.train_config.timestep_type in ['two_step', 'four_step', 'eight_step']: + if self.train_config.timestep_type == 'two_step': + indice_choices = [0, 499] + elif self.train_config.timestep_type == 'four_step': + indice_choices = [0, 250, 500, 750] + elif self.train_config.timestep_type == 'eight_step': + indice_choices = [0, 125, 250, 375, 500, 625, 750, 875] + timestep_indices = torch.tensor(random.choices(indice_choices, k=batch_size), device=self.device_torch) + timestep_indices = timestep_indices.long() + elif self.train_config.timestep_type == 'next_sample': + timestep_indices = torch.randint( + 0, + num_train_timesteps - 2, # -1 for 0 idx, -1 so we can step + (batch_size,), + device=self.device_torch + ) + timestep_indices = timestep_indices.long() + elif self.train_config.timestep_type == 'one_step': + timestep_indices = torch.zeros((batch_size,), device=self.device_torch, dtype=torch.long) + elif content_or_style in ['style', 'content']: + # this is from diffusers training code + # Cubic sampling for favoring later or earlier timesteps + # For more details about why cubic sampling is used for content / structure, + # refer to section 3.4 of https://arxiv.org/abs/2302.08453 + + # for content / structure, it is best to favor earlier timesteps + # for style, it is best to favor later timesteps + + orig_timesteps = torch.rand((batch_size,), device=latents.device) + + if content_or_style == 'content': + timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps + elif content_or_style == 'style': + timestep_indices = (1 - orig_timesteps ** 3) * 
self.train_config.num_train_timesteps + + timestep_indices = value_map( + timestep_indices, + 0, + self.train_config.num_train_timesteps - 1, + min_noise_steps, + max_noise_steps + ) + timestep_indices = timestep_indices.long().clamp( + min_noise_steps, + max_noise_steps + ) + + elif content_or_style == 'balanced': + if min_noise_steps == max_noise_steps: + timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps + else: + # todo, some schedulers use indices, otheres use timesteps. Not sure what to do here + min_idx = min_noise_steps + 1 + max_idx = max_noise_steps - 1 + if self.train_config.noise_scheduler == 'flowmatch': + # flowmatch uses indices, so we need to use indices + min_idx = min_noise_steps + max_idx = max_noise_steps + timestep_indices = torch.randint( + min_idx, + max_idx, + (batch_size,), + device=self.device_torch + ) + timestep_indices = timestep_indices.long() + else: + raise ValueError(f"Unknown content_or_style {content_or_style}") + with self.timer('convert_timestep_indices_to_timesteps'): + # convert the timestep_indices to a timestep + timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()] + + with self.timer('prepare_noise'): + # get noise + noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps) + + # add dynamic noise offset. 
Dynamic noise is offsetting the noise to the same channelwise mean as the latents + # this will negate any noise offsets + if self.train_config.dynamic_noise_offset and not is_reg: + latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2 + # subtract channel mean to that we compensate for the mean of the latents on the noise offset per channel + noise = noise + latents_channel_mean + + if self.train_config.loss_target == 'differential_noise': + differential = latents - unaugmented_latents + # add noise to differential + # noise = noise + differential + noise = noise + (differential * 0.5) + # noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max()) + latents = unaugmented_latents + + noise_multiplier = self.train_config.noise_multiplier + + s = (noise.shape[0], noise.shape[1], 1, 1) + if len(noise.shape) == 5: + # if we have a 5d tensor, then we need to do it on a per batch item, per channel basis, per frame + s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1) + + noise = noise * noise_multiplier + + if self.train_config.do_signal_correction_noise: + batch_noise = latents.clone().to(noise.device, dtype=noise.dtype) + scn_scale = torch.randn( + batch_noise.shape[0], batch_noise.shape[1], 1, 1, + device=batch_noise.device, + dtype=batch_noise.dtype + ) * self.train_config.signal_correction_noise_scale + batch_noise = batch_noise * scn_scale + noise = noise + batch_noise + + if self.train_config.do_batch_noise_correction: + if latents.shape[0] == 1: + # if we only have a batch size of 1, then we cant do batch noise correction, so we skip it + print_acc("Skipping batch noise correction because batch size is 1, increase batch size and num_repeats to use this feature") + else: + # shuffle tensors ensuring that no tensor is in the same position as before + batch_noise = latents.clone().roll(shifts=torch.randint(1, latents.shape[0], (1,)).item(), dims=0).to(noise.device, dtype=noise.dtype) + batch_noise_scale = 
torch.randn( + batch_noise.shape[0], batch_noise.shape[1], 1, 1, + device=batch_noise.device, + dtype=batch_noise.dtype + ) * self.train_config.batch_noise_correction_scale + batch_noise = batch_noise * batch_noise_scale + noise = noise + batch_noise + + if self.train_config.random_noise_shift > 0.0: + # get random noise -1 to 1 + noise_shift = torch.randn( + batch_size, latents.shape[1], 1, 1, + device=noise.device, + dtype=noise.dtype + ) * self.train_config.random_noise_shift + # add to noise + noise += noise_shift + + if self.train_config.random_noise_multiplier > 0.0: + sigma = self.train_config.random_noise_multiplier + noise_multiplier = torch.exp(torch.randn(s, device=noise.device, dtype=noise.dtype) * sigma) + noise = noise * noise_multiplier + with self.timer('make_noisy_latents'): + + latent_multiplier = self.train_config.latent_multiplier + + # handle adaptive scaling mased on std + if self.train_config.adaptive_scaling_factor: + std = latents.std(dim=(2, 3), keepdim=True) + normalizer = 1 / (std + 1e-6) + latent_multiplier = normalizer + + latents = latents * latent_multiplier + + if self.train_config.do_blank_stabilization: + # zero out latents with blank prompts + blank_latent = torch.zeros_like(latents) + for i, prompt in enumerate(conditioned_prompts): + if prompt.strip() == '': + latents[i] = blank_latent[i] + + batch.latents = latents + + # normalize latents to a mean of 0 and an std of 1 + # mean_zero_latents = latents - latents.mean() + # latents = mean_zero_latents / mean_zero_latents.std() + + if batch.unconditional_latents is not None: + batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier + + + noisy_latents = self.sd.add_noise(latents, noise, timesteps) + + # determine scaled noise + # todo do we need to scale this or does it always predict full intensity + # noise = noisy_latents - latents + + # 
def setup_adapter(self):
    """Create the adapter described by ``self.adapter_config`` and attach it to ``self.sd``.

    The adapter class is picked from ``self.adapter_config.type``:
    ``'t2i'`` -> ``T2IAdapter``, ``'control_net'`` -> ``ControlNetModel``,
    ``'clip'`` -> ``ClipVisionAdapter``, ``'reference'`` -> ``ReferenceAdapter``,
    any type starting with ``'ip'`` -> ``IPAdapter``, everything else ->
    ``CustomAdapter``.

    If ``get_latest_save_path`` finds a previous save, its weights are loaded
    into the new adapter (ControlNet instead loads directly from that path via
    ``from_pretrained``). When the adapter is being trained, the training state
    stored with that save is restored as well.

    Raises:
        ValueError: for a ControlNet adapter without ``name_or_path``.
    """
    adapter_type = self.adapter_config.type
    is_t2i = adapter_type == 't2i'
    is_control_net = adapter_type == 'control_net'

    # Exact type names map to a fixed save-name suffix; only when none of them
    # match do we fall back to the 'ip*' prefix check and the generic suffix.
    exact_suffixes = {
        't2i': 't2i',
        'control_net': 'cn',
        'clip': 'clip',
        'reference': 'ref',
    }
    suffix = exact_suffixes.get(adapter_type)
    if suffix is None:
        suffix = 'ip' if adapter_type.startswith('ip') else 'adapter'

    # The suffix is only appended when a network is configured as well, so the
    # two sets of saved files do not share a name.
    save_name = self.name
    if self.network_config is not None:
        save_name = f"{save_name}_{suffix}"
    latest_save_path = self.get_latest_save_path(save_name)

    if latest_save_path is not None and not self.adapter_config.train:
        # We are not training the adapter, so the save that was found belongs
        # to something else; use the configured path instead.
        latest_save_path = self.adapter_config.name_or_path

    dtype = get_torch_dtype(self.train_config.dtype)

    if is_t2i:
        # Build from config when there is a save to resume from (its weights
        # are loaded further down) or when no pretrained path is configured.
        build_from_config = (
            latest_save_path is not None
            or self.adapter_config.name_or_path is None
        )
        if build_from_config:
            self.adapter = T2IAdapter(
                in_channels=self.adapter_config.in_channels,
                channels=self.adapter_config.channels,
                num_res_blocks=self.adapter_config.num_res_blocks,
                downscale_factor=self.adapter_config.downscale_factor,
                adapter_type=self.adapter_config.adapter_type,
            )
        else:
            # NOTE(review): the keyword is spelled `varient` in the original and
            # is kept byte-for-byte here. It looks like a typo for `variant` --
            # confirm the intended behavior before correcting it, since that
            # would presumably change which weight files are requested.
            self.adapter = T2IAdapter.from_pretrained(
                self.adapter_config.name_or_path,
                torch_dtype=get_torch_dtype(self.train_config.dtype),
                varient="fp16",
            )
    elif is_control_net:
        if self.adapter_config.name_or_path is None:
            raise ValueError("ControlNet requires a name_or_path to load from currently")
        # A previous save takes precedence over the configured path.
        if latest_save_path is not None:
            load_from_path = latest_save_path
        else:
            load_from_path = self.adapter_config.name_or_path
        self.adapter = ControlNetModel.from_pretrained(
            load_from_path,
            torch_dtype=get_torch_dtype(self.train_config.dtype),
        )
    else:
        # The remaining adapter classes all take the sd model and adapter config.
        wrapper_kwargs = {
            'sd': self.sd,
            'adapter_config': self.adapter_config,
        }
        if adapter_type == 'clip':
            self.adapter = ClipVisionAdapter(**wrapper_kwargs)
        elif adapter_type == 'reference':
            self.adapter = ReferenceAdapter(**wrapper_kwargs)
        elif adapter_type.startswith('ip'):
            self.adapter = IPAdapter(**wrapper_kwargs)
            if self.train_config.gradient_checkpointing:
                self.adapter.enable_gradient_checkpointing()
        else:
            self.adapter = CustomAdapter(
                train_config=self.train_config,
                **wrapper_kwargs
            )

    self.adapter.to(self.device_torch, dtype=dtype)

    # ControlNet is excluded: it was already loaded from the path above.
    if latest_save_path is not None and not is_control_net:
        print_acc(f"Loading adapter from {latest_save_path}")
        if is_t2i:
            restored_state = load_t2i_model(
                latest_save_path,
                self.device,
                dtype=dtype
            )
        elif adapter_type.startswith('ip'):
            restored_state = load_ip_adapter_model(
                latest_save_path,
                self.device,
                dtype=dtype,
                direct_load=self.adapter_config.train_only_image_encoder
            )
        else:
            # clip, reference and custom adapters all use the custom loader
            restored_state = load_custom_adapter_model(
                latest_save_path,
                self.device,
                dtype=dtype
            )
        self.adapter.load_state_dict(restored_state)

    if latest_save_path is not None and self.adapter_config.train:
        self.load_training_state_from_metadata(latest_save_path)

    # hand the adapter to the sd model
    self.sd.adapter = self.adapter
print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + model_config_to_load.name_or_path = latest_save_path + self.load_training_state_from_metadata(latest_save_path) + + ModelClass = get_model_class(self.model_config) + # if the model class has get_train_scheduler static method + if hasattr(ModelClass, 'get_train_scheduler'): + sampler = ModelClass.get_train_scheduler() + else: + # get the noise scheduler + arch = 'sd' + if self.model_config.is_pixart: + arch = 'pixart' + if self.model_config.is_flux: + arch = 'flux' + if self.model_config.is_lumina2: + arch = 'lumina2' + sampler = get_sampler( + self.train_config.noise_scheduler, + { + "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", + }, + arch=arch, + ) + + if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None: + previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner') + if previous_refiner_save is not None: + model_config_to_load.refiner_name_or_path = previous_refiner_save + self.load_training_state_from_metadata(previous_refiner_save) + + self.sd = ModelClass( + # todo handle single gpu and multi gpu here + # device=self.device, + device=self.accelerator.device, + model_config=model_config_to_load, + dtype=self.train_config.dtype, + custom_pipeline=self.custom_pipeline, + noise_scheduler=sampler, + ) + + self.hook_after_sd_init_before_load() + # run base sd process run + self.sd.load_model() + + self.sd.add_after_sample_image_hook(self.sample_step_hook) + + dtype = get_torch_dtype(self.train_config.dtype) + + # model is loaded from BaseSDProcess + unet = self.sd.unet + vae = self.sd.vae + tokenizer = self.sd.tokenizer + text_encoder = self.sd.text_encoder + noise_scheduler = self.sd.noise_scheduler + + if self.train_config.xformers: + vae.enable_xformers_memory_efficient_attention() + unet.enable_xformers_memory_efficient_attention() + if isinstance(text_encoder, list): + for 
te in text_encoder: + # if it has it + if hasattr(te, 'enable_xformers_memory_efficient_attention'): + te.enable_xformers_memory_efficient_attention() + + if self.train_config.attention_backend != 'native': + if hasattr(vae, 'set_attention_backend'): + vae.set_attention_backend(self.train_config.attention_backend) + if hasattr(unet, 'set_attention_backend'): + unet.set_attention_backend(self.train_config.attention_backend) + if isinstance(text_encoder, list): + for te in text_encoder: + if hasattr(te, 'set_attention_backend'): + te.set_attention_backend(self.train_config.attention_backend) + else: + if hasattr(text_encoder, 'set_attention_backend'): + text_encoder.set_attention_backend(self.train_config.attention_backend) + if self.train_config.sdp: + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + + # # check if we have sage and is flux + # if self.sd.is_flux: + # # try_to_activate_sage_attn() + # try: + # from sageattention import sageattn + # from toolkit.models.flux_sage_attn import FluxSageAttnProcessor2_0 + # model: FluxTransformer2DModel = self.sd.unet + # # enable sage attention on each block + # for block in model.transformer_blocks: + # processor = FluxSageAttnProcessor2_0() + # block.attn.set_processor(processor) + # for block in model.single_transformer_blocks: + # processor = FluxSageAttnProcessor2_0() + # block.attn.set_processor(processor) + + # except ImportError: + # print_acc("sage attention is not installed. 
Using SDP instead") + + if self.train_config.gradient_checkpointing: + # if has method enable_gradient_checkpointing + if hasattr(unet, 'enable_gradient_checkpointing'): + unet.enable_gradient_checkpointing() + elif hasattr(unet, 'gradient_checkpointing'): + unet.gradient_checkpointing = True + else: + print("Gradient checkpointing not supported on this model") + if isinstance(text_encoder, list): + for te in text_encoder: + if hasattr(te, 'enable_gradient_checkpointing'): + te.enable_gradient_checkpointing() + if hasattr(te, "gradient_checkpointing_enable"): + te.gradient_checkpointing_enable() + else: + if hasattr(text_encoder, 'enable_gradient_checkpointing'): + text_encoder.enable_gradient_checkpointing() + if hasattr(text_encoder, "gradient_checkpointing_enable"): + text_encoder.gradient_checkpointing_enable() + + if self.sd.refiner_unet is not None: + self.sd.refiner_unet.to(self.device_torch, dtype=dtype) + self.sd.refiner_unet.requires_grad_(False) + self.sd.refiner_unet.eval() + if self.train_config.xformers: + self.sd.refiner_unet.enable_xformers_memory_efficient_attention() + if self.train_config.gradient_checkpointing: + self.sd.refiner_unet.enable_gradient_checkpointing() + + if isinstance(text_encoder, list): + for te in text_encoder: + te.requires_grad_(False) + te.eval() + else: + text_encoder.requires_grad_(False) + text_encoder.eval() + unet.to(self.device_torch, dtype=dtype) + unet.requires_grad_(False) + unet.eval() + vae = vae.to(torch.device('cpu'), dtype=dtype) + vae.requires_grad_(False) + vae.eval() + if self.train_config.learnable_snr_gos: + self.snr_gos = LearnableSNRGamma( + self.sd.noise_scheduler, device=self.device_torch + ) + # check to see if previous settings exist + path_to_load = os.path.join(self.save_root, 'learnable_snr.json') + if os.path.exists(path_to_load): + with open(path_to_load, 'r') as f: + json_data = json.load(f) + if 'offset' in json_data: + # legacy + self.snr_gos.offset_2.data = torch.tensor(json_data['offset'], 
device=self.device_torch) + else: + self.snr_gos.offset_1.data = torch.tensor(json_data['offset_1'], device=self.device_torch) + self.snr_gos.offset_2.data = torch.tensor(json_data['offset_2'], device=self.device_torch) + self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch) + self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch) + + self.hook_after_model_load() + flush() + if not self.is_fine_tuning: + if self.network_config is not None: + # TODO should we completely switch to LycorisSpecialNetwork? + network_kwargs = self.network_config.network_kwargs + is_lycoris = False + is_lorm = self.network_config.type.lower() == 'lorm' + # default to LoCON if there are any conv layers or if it is named + NetworkClass = LoRASpecialNetwork + if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris': + NetworkClass = LycorisSpecialNetwork + is_lycoris = True + + if is_lorm: + network_kwargs['ignore_if_contains'] = lorm_ignore_if_contains + network_kwargs['parameter_threshold'] = lorm_parameter_threshold + network_kwargs['target_lin_modules'] = LORM_TARGET_REPLACE_MODULE + + # if is_lycoris: + # preset = PRESET['full'] + # NetworkClass.apply_preset(preset) + + if hasattr(self.sd, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = self.sd.target_lora_modules + + self.network = NetworkClass( + text_encoder=text_encoder, + unet=self.sd.get_model_to_train(), + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + 
is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=is_lorm, + is_lorm=is_lorm, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=self.sd.is_transformer, + base_model=self.sd, + **network_kwargs + ) + + + # todo switch everything to proper mixed precision like this + self.network.force_to(self.device_torch, dtype=torch.float32) + # give network to sd so it can use it + self.sd.network = self.network + self.network._update_torch_multiplier() + + self.network.apply_to( + text_encoder, + unet, + self.train_config.train_text_encoder, + self.train_config.train_unet + ) + + # we cannot merge in if quantized or offloading. note: torchao quantized weights can + # still be force merged at save time for the merge-and-reset method (see save logic), + # but we keep can_merge_in False here so sampling never merges in/out. 
+ if self.model_config.quantize or self.model_config.layer_offloading: + # todo find a way around this + self.network.can_merge_in = False + + if is_lorm: + self.network.is_lorm = True + # make sure it is on the right device + self.sd.unet.to(self.sd.device, dtype=dtype) + original_unet_param_count = count_parameters(self.sd.unet) + self.network.setup_lorm() + new_unet_param_count = original_unet_param_count - self.network.calculate_lorem_parameter_reduction() + + print_lorm_extract_details( + start_num_params=original_unet_param_count, + end_num_params=new_unet_param_count, + num_replaced=len(self.network.get_all_modules()), + ) + + self.network.prepare_grad_etc(text_encoder, unet) + flush() + + # LyCORIS doesnt have default_lr + config = { + 'text_encoder_lr': self.train_config.lr, + 'unet_lr': self.train_config.lr, + } + sig = inspect.signature(self.network.prepare_optimizer_params) + if 'default_lr' in sig.parameters: + config['default_lr'] = self.train_config.lr + if 'learning_rate' in sig.parameters: + config['learning_rate'] = self.train_config.lr + params_net = self.network.prepare_optimizer_params( + **config + ) + + params += params_net + + if self.train_config.gradient_checkpointing: + self.network.enable_gradient_checkpointing() + + lora_name = self.name + # need to adapt name so they are not mixed up + if self.named_lora: + lora_name = f"{lora_name}_LoRA" + + latest_save_path = self.get_latest_save_path(lora_name) + extra_weights = None + if latest_save_path is not None and not self.train_config.merge_network_on_save: + print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + print_acc(f"Loading from {latest_save_path}") + extra_weights = self.load_weights(latest_save_path) + self.network.multiplier = 1.0 + + if self.network_config.layer_offloading: + MemoryManager.attach( + self.network, + self.device_torch + ) + + if self.embed_config is not None: + # we are doing embedding training as well + self.embedding = Embedding( + sd=self.sd, + 
embed_config=self.embed_config + ) + latest_save_path = self.get_latest_save_path(self.embed_config.trigger) + # load last saved weights + if latest_save_path is not None: + self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) + if self.embedding.step > 1: + self.step_num = self.embedding.step + self.start_step = self.step_num + + # self.step_num = self.embedding.step + # self.start_step = self.step_num + params.append({ + 'params': list(self.embedding.get_trainable_params()), + 'lr': self.train_config.embedding_lr + }) + + flush() + + if self.decorator_config is not None: + self.decorator = Decorator( + num_tokens=self.decorator_config.num_tokens, + token_size=4096 # t5xxl hidden size for flux + ) + latest_save_path = self.get_latest_save_path() + # load last saved weights + if latest_save_path is not None: + state_dict = load_file(latest_save_path) + self.decorator.load_state_dict(state_dict) + self.load_training_state_from_metadata(latest_save_path) + + params.append({ + 'params': list(self.decorator.parameters()), + 'lr': self.train_config.lr + }) + + # give it to the sd network + self.sd.decorator = self.decorator + self.decorator.to(self.device_torch, dtype=torch.float32) + self.decorator.train() + + flush() + + if self.adapter_config is not None: + self.setup_adapter() + if self.adapter_config.train: + + if isinstance(self.adapter, IPAdapter): + # we have custom LR groups for IPAdapter + adapter_param_groups = self.adapter.get_parameter_groups(self.train_config.adapter_lr) + for group in adapter_param_groups: + params.append(group) + else: + # set trainable params + params.append({ + 'params': list(self.adapter.parameters()), + 'lr': self.train_config.adapter_lr + }) + + if self.train_config.gradient_checkpointing: + self.adapter.enable_gradient_checkpointing() + flush() + + params = self.load_additional_training_modules(params) + + else: # no network, embedding or adapter + # set the device state preset before getting params + 
self.sd.set_device_state(self.get_params_device_state_preset) + + # params = self.get_params() + if len(params) == 0: + # will only return savable weights and ones with grad + params = self.sd.prepare_optimizer_params( + unet=self.train_config.train_unet, + text_encoder=self.train_config.train_text_encoder, + text_encoder_lr=self.train_config.lr, + unet_lr=self.train_config.lr, + default_lr=self.train_config.lr, + refiner=self.train_config.train_refiner and self.sd.refiner_unet is not None, + refiner_lr=self.train_config.refiner_lr, + ) + # we may be using it for prompt injections + if self.adapter_config is not None and self.adapter is None: + self.setup_adapter() + flush() + + ### HOOK ### + params = self.hook_add_extra_train_params(params) + self.params = params + # self.params = [] + + # for param in params: + # if isinstance(param, dict): + # self.params += param['params'] + # else: + # self.params.append(param) + + if self.train_config.start_step is not None: + self.step_num = self.train_config.start_step + self.start_step = self.step_num + + optimizer_type = self.train_config.optimizer.lower() + + # esure params require grad + self.ensure_params_requires_grad(force=True) + optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr, + optimizer_params=self.train_config.optimizer_params) + self.optimizer = optimizer + + # set it to do paramiter swapping + if self.train_config.do_paramiter_swapping: + # only works for adafactor, but it should have thrown an error prior to this otherwise + self.optimizer.enable_paramiter_swapping(self.train_config.paramiter_swapping_factor) + + # check if it exists + optimizer_state_filename = f'optimizer.pt' + optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename) + if os.path.exists(optimizer_state_file_path): + # try to load + # previous param groups + # previous_params = copy.deepcopy(optimizer.param_groups) + previous_lrs = [] + for group in optimizer.param_groups: 
+ previous_lrs.append(group['lr']) + + load_optimizer = True + if self.network is not None: + if self.network.did_change_weights: + # do not load optimizer if the network changed, it will result in + # a double state that will oom. + load_optimizer = False + + if load_optimizer: + try: + print_acc(f"Loading optimizer state from {optimizer_state_file_path}") + optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) + optimizer.load_state_dict(optimizer_state_dict) + del optimizer_state_dict + flush() + except Exception as e: + print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}") + print_acc(e) + + # update the optimizer LR from the params + print_acc(f"Updating optimizer LR from params") + if len(previous_lrs) > 0: + for i, group in enumerate(optimizer.param_groups): + group['lr'] = previous_lrs[i] + group['initial_lr'] = previous_lrs[i] + + # Update the learning rates if they changed + # optimizer.param_groups = previous_params + + # set up the ema now that the optimizer (and its params) are ready + self.setup_ema() + + lr_scheduler_params = self.train_config.lr_scheduler_params + + # make sure it had bare minimum + if 'max_iterations' not in lr_scheduler_params: + lr_scheduler_params['total_iters'] = self.train_config.steps + + lr_scheduler = get_lr_scheduler( + self.train_config.lr_scheduler, + optimizer, + **lr_scheduler_params + ) + self.lr_scheduler = lr_scheduler + + ### HOOk ### + self.before_dataset_load() + # load datasets if passed in the root process + if self.datasets is not None: + self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd) + if self.datasets_reg is not None: + self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, + self.sd) + + flush() + self.last_save_step = self.step_num + ### HOOK ### + self.hook_before_train_loop() + + # ============================================================ + # COMPILE + 
# + # compile: true + # -> whole-model torch.compile + # + # compile: true + # block_compile: true + # -> block-level compilation + # ============================================================ + if self.model_config.compile: + compiled_refs = [] # (block_list, index, original_block) for rollback on failure + try: + inner_unet_check = unwrap_model(self.sd.unet) + is_unet_offloaded = hasattr(inner_unet_check, '_memory_manager') + + text_encoder = getattr(self.sd, "text_encoder", None) + text_encoder_check = unwrap_model(text_encoder) if text_encoder is not None else None + is_te_offloaded = hasattr(text_encoder_check, '_memory_manager') if text_encoder_check is not None else False + + is_unet_quantized = getattr(self.model_config, 'quantize', False) + is_quantized = is_unet_quantized or getattr(self.model_config, 'quantize_te', False) + + if not is_unet_offloaded: + self.sd.unet.to(self.device_torch) + + cache_size_limit = getattr(self.model_config, 'cache_size_limit', None) + user_set_cache_limit = cache_size_limit is not None + if user_set_cache_limit: + torch._dynamo.config.cache_size_limit = cache_size_limit + torch._dynamo.config.suppress_errors = False + + compile_mode = getattr(self.model_config, 'compile_mode', 'default') + compile_dynamic = getattr(self.model_config, 'compile_dynamic', True) + compile_fullgraph = getattr(self.model_config, 'compile_fullgraph', False) + block_compile = getattr(self.model_config, 'block_compile', False) + + # quantized + offloaded unet is incompatible with fullgraph; force it off + if is_unet_quantized and is_unet_offloaded and compile_fullgraph: + print_acc( + "Quantized offloaded Transformer detected: fullgraph=True is incompatible, " + "switching to fullgraph=False." 
+ ) + compile_fullgraph = False + + cache_info = "" + # ==================================================== + # BLOCK COMPILE + # ==================================================== + if block_compile: + BLOCK_LIST_ATTRS = self.sd.get_transformer_block_names() + + if BLOCK_LIST_ATTRS is None or len(BLOCK_LIST_ATTRS) == 0: + BLOCK_LIST_ATTRS = [ + 'layers', + 'transformer_blocks', + 'single_transformer_blocks', + 'double_stream_blocks', + 'single_stream_blocks', + 'double_blocks', + 'single_blocks', + 'blocks', + ] + inner_unet = unwrap_model(self.sd.unet) + + compiled_block_count = 0 + + for attr_name in BLOCK_LIST_ATTRS: + # attr_name may be a dotted path for models that nest their + # blocks (e.g. hidream_o1's "model.language_model.layers"). + block_list = inner_unet + for part in attr_name.split('.'): + block_list = getattr(block_list, part, None) + if block_list is None: + break + + if block_list is None: + continue + + if not hasattr(block_list, '__len__'): + continue + + for i, block in enumerate(block_list): + if not isinstance(block, torch.nn.Module): + continue + + if hasattr(block, '_hf_hook'): + continue + + compiled_refs.append((block_list, i, block)) + block_list[i] = torch.compile( + block, + mode=compile_mode, + dynamic=compile_dynamic, + fullgraph=compile_fullgraph, + ) + compiled_block_count += 1 + + if compiled_block_count > 0: + if user_set_cache_limit: + auto_cache_limit = max(cache_size_limit, compiled_block_count * 2) + if auto_cache_limit != cache_size_limit: + torch._dynamo.config.cache_size_limit = auto_cache_limit + cache_info = f", cache_size_limit={auto_cache_limit} (auto)" + else: + cache_info = f", cache_size_limit={cache_size_limit}" + else: + auto_cache_limit = compiled_block_count * 2 + torch._dynamo.config.cache_size_limit = auto_cache_limit + cache_info = f", cache_size_limit={auto_cache_limit} (auto)" + print_acc( + f"Compiled {compiled_block_count} transformer block(s) " + f"with torch.compile (mode='{compile_mode}', 
fullgraph={compile_fullgraph}, dynamic={compile_dynamic}{cache_info})." + ) + print_acc("The first forward pass will be slow during compile. This is normal.") + print_acc("If you are experiencing issues, disable block_compile.") + else: + print_acc( + f"No individual transformer blocks found; " + f"falling back to whole-model torch.compile " + f"(mode='{compile_mode}', fullgraph={compile_fullgraph}, dynamic={compile_dynamic}{cache_info})." + ) + print_acc("The first forward pass will hang for a while. This is normal.") + + if is_unet_quantized and not is_unet_offloaded and compile_fullgraph: + print_acc( + "Quantized model detected: fullgraph=True is incompatible " + "for whole-model compile, switching to fullgraph=False." + ) + compile_fullgraph = False + + if compile_mode == 'default': + self.sd.unet = torch.compile( + self.sd.unet, + dynamic=compile_dynamic, + fullgraph=compile_fullgraph, + ) + else: + self.sd.unet = torch.compile( + self.sd.unet, + mode=compile_mode, + dynamic=compile_dynamic, + fullgraph=compile_fullgraph, + ) + + # ==================================================== + # WHOLE MODEL COMPILE + # ==================================================== + else: + print_acc("Compiling model with torch.compile (whole-model compile).") + print_acc("The first forward pass will hang for a while. This is normal.") + + print_acc( + f"Using torch.compile settings: " + f"mode={compile_mode}, " + f"dynamic={compile_dynamic}, " + f"fullgraph={compile_fullgraph}{cache_info}" + ) + + if compile_fullgraph: + print_acc( + "fullgraph=True is incompatible with whole-model compile, " + "switching to fullgraph=False." 
+ ) + compile_fullgraph = False + + if compile_mode == 'default': + self.sd.unet = torch.compile( + self.sd.unet, + dynamic=compile_dynamic, + fullgraph=compile_fullgraph, + ) + else: + self.sd.unet = torch.compile( + self.sd.unet, + mode=compile_mode, + dynamic=compile_dynamic, + fullgraph=compile_fullgraph, + ) + + if not is_unet_offloaded: + # once compiled, dynamo guards hold weakrefs to the params; + # .to() on quantized params requires swap_tensors, which fails + # on tensors with weakrefs. The model stays on device anyway, + # so make .to() a no-op. + unet_module = self.sd.unet + unet_module.to = lambda *args, **kwargs: unet_module + + except Exception as e: + # undo any block-level compiles that happened before the failure, + # so "continuing without compilation" is actually true + if len(compiled_refs) > 0: + for block_list, i, original_block in compiled_refs: + block_list[i] = original_block + + if 'triton' in str(e).lower(): + print_acc("WARNING: compile is disabled.") + print_acc("Triton is not available or not working on this system.") + print_acc("Install a working 'triton' package to use compile.") + print_acc("Continuing without compilation.") + else: + print_acc(f"Failed to compile model: {e}") + print_acc("Continuing without compilation") + + if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling: + print_acc("Generating first sample from first sample config") + self.sample(0, is_first=True) + + # sample first + if self.train_config.skip_first_sample or self.train_config.disable_sampling: + print_acc("Skipping first sample due to config setting") + elif self.step_num <= 1 or self.train_config.force_first_sample: + print_acc("Generating baseline samples before training") + self.sample(self.step_num) + + if self.accelerator.is_local_main_process: + self.progress_bar = ToolkitProgressBar( + total=self.train_config.steps, + desc=self.job.name, + leave=True, + initial=self.step_num, + iterable=range(0, 
self.train_config.steps), + ) + self.progress_bar.pause() + else: + self.progress_bar = None + + if self.data_loader is not None: + dataloader = self.data_loader + dataloader_iterator = iter(dataloader) + else: + dataloader = None + dataloader_iterator = None + + if self.data_loader_reg is not None: + dataloader_reg = self.data_loader_reg + dataloader_iterator_reg = iter(dataloader_reg) + else: + dataloader_reg = None + dataloader_iterator_reg = None + + # zero any gradients + optimizer.zero_grad() + + self.lr_scheduler.step(self.step_num) + + self.sd.set_device_state(self.train_device_state_preset) + flush() + # self.step_num = 0 + + # print_acc(f"Compiling Model") + # torch.compile(self.sd.unet, dynamic=True) + + # make sure all params require grad + self.ensure_params_requires_grad(force=True) + + + ################################################################### + # TRAIN LOOP + ################################################################### + + + start_step_num = self.step_num + did_first_flush = False + flush_next = False + for step in range(start_step_num, self.train_config.steps): + if self.train_config.do_paramiter_swapping: + self.optimizer.optimizer.swap_paramiters() + self.timer.start('train_loop') + if flush_next: + flush() + flush_next = False + if self.train_config.do_random_cfg: + self.train_config.do_cfg = True + self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale) + self.step_num = step + # default to true so various things can turn it off + self.is_grad_accumulation_step = True + if self.train_config.free_u: + self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2) + if self.progress_bar is not None: + self.progress_bar.unpause() + with torch.no_grad(): + # if is even step and we have a reg dataset, use that + # todo improve this logic to send one of each through if we can buckets and batch size might be an issue + is_reg_step = False + is_save_step = self.save_config.save_every and 
self.step_num % self.save_config.save_every == 0 + is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0 + if self.train_config.disable_sampling: + is_sample_step = False + + batch_list = [] + + for b in range(self.train_config.gradient_accumulation): + # keep track to alternate on an accumulation step for reg + batch_step = step + # don't do a reg step on sample or save steps as we dont want to normalize on those + if batch_step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step: + try: + with self.timer('get_batch:reg'): + batch = next(dataloader_iterator_reg) + except StopIteration: + with self.timer('reset_batch:reg'): + # hit the end of an epoch, reset + if self.progress_bar is not None: + self.progress_bar.pause() + dataloader_iterator_reg = iter(dataloader_reg) + trigger_dataloader_setup_epoch(dataloader_reg) + + with self.timer('get_batch:reg'): + batch = next(dataloader_iterator_reg) + if self.progress_bar is not None: + self.progress_bar.unpause() + is_reg_step = True + elif dataloader is not None: + try: + with self.timer('get_batch'): + batch = next(dataloader_iterator) + except StopIteration: + with self.timer('reset_batch'): + # hit the end of an epoch, reset + if self.progress_bar is not None: + self.progress_bar.pause() + dataloader_iterator = iter(dataloader) + trigger_dataloader_setup_epoch(dataloader) + self.epoch_num += 1 + if self.train_config.gradient_accumulation_steps == -1: + # if we are accumulating for an entire epoch, trigger a step + self.is_grad_accumulation_step = False + self.grad_accumulation_step = 0 + with self.timer('get_batch'): + batch = next(dataloader_iterator) + if self.progress_bar is not None: + self.progress_bar.unpause() + else: + batch = None + batch_list.append(batch) + batch_step += 1 + + # setup accumulation + if self.train_config.gradient_accumulation_steps == -1: + # epoch is handling the accumulation, dont touch it + pass + else: 
+ # determine if we are accumulating or not + # since optimizer step happens in the loop, we trigger it a step early + # since we cannot reprocess it before them + optimizer_step_at = self.train_config.gradient_accumulation_steps + is_optimizer_step = self.grad_accumulation_step >= optimizer_step_at + self.is_grad_accumulation_step = not is_optimizer_step + if is_optimizer_step: + self.grad_accumulation_step = 0 + + # flush() + ### HOOK ### + if self.torch_profiler is not None: + self.torch_profiler.start() + did_oom = False + loss_dict = None + try: + with self.accelerator.accumulate(self.modules_being_trained): + loss_dict = self.hook_train_loop(batch_list) + except torch.cuda.OutOfMemoryError: + did_oom = True + except RuntimeError as e: + if "CUDA out of memory" in str(e): + did_oom = True + else: + raise # not an OOM; surface real errors + if did_oom: + self.num_consecutive_oom += 1 + if self.num_consecutive_oom > 3: + raise RuntimeError("OOM during training step 3 times in a row, aborting training") + optimizer.zero_grad(set_to_none=True) + flush() + torch.cuda.ipc_collect() + # skip this step and keep going + print_acc("") + print_acc("################################################") + print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #") + print_acc("################################################") + print_acc("") + else: + self.num_consecutive_oom = 0 + if self.torch_profiler is not None: + torch.cuda.synchronize() # Make sure all CUDA ops are done + self.torch_profiler.stop() + + print("\n==== Profile Results ====") + print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000)) + self.timer.stop('train_loop') + if not did_first_flush: + flush() + did_first_flush = True + # flush() + # setup the networks to gradient checkpointing and everything works + if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() + + with torch.no_grad(): + # 
torch.cuda.empty_cache() + # if optimizer has get_lrs method, then use it + learning_rate = 0.0 + if not did_oom and loss_dict is not None: + if hasattr(optimizer, 'get_avg_learning_rate'): + learning_rate = optimizer.get_avg_learning_rate() + elif hasattr(optimizer, 'get_learning_rates'): + learning_rate = optimizer.get_learning_rates()[0] + elif self.train_config.optimizer.lower().startswith('dadaptation') or \ + self.train_config.optimizer.lower().startswith('prodigy'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] + + prog_bar_string = f"lr: {learning_rate:.1e}" + for key, value in loss_dict.items(): + prog_bar_string += f" {key}: {value:.3e}" + + if self.progress_bar is not None: + self.progress_bar.set_postfix_str(prog_bar_string) + + # if the batch is a DataLoaderBatchDTO, then we need to clean it up + if isinstance(batch, DataLoaderBatchDTO): + with self.timer('batch_cleanup'): + batch.cleanup() + + # don't do on first step + if self.step_num != self.start_step: + if is_sample_step or is_save_step: + self.accelerator.wait_for_everyone() + + if is_save_step: + self.accelerator + # print above the progress bar + if self.progress_bar is not None: + self.progress_bar.pause() + print_acc(f"\nSaving at step {self.step_num}") + self.save(self.step_num) + self.ensure_params_requires_grad() + # clear any grads + optimizer.zero_grad() + flush() + flush_next = True + if self.progress_bar is not None: + self.progress_bar.unpause() + + if is_sample_step: + if self.progress_bar is not None: + self.progress_bar.pause() + flush() + # print above the progress bar + if self.train_config.free_u: + self.sd.pipeline.disable_freeu() + self.sample(self.step_num) + if self.train_config.unload_text_encoder: + # make sure the text encoder is unloaded + self.sd.text_encoder_to('cpu') + flush() + + self.ensure_params_requires_grad() + if self.progress_bar is not None: + 
self.progress_bar.unpause() + + if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: + if self.progress_bar is not None: + self.progress_bar.pause() + with self.timer('log_to_tensorboard'): + # log to tensorboard + if self.accelerator.is_main_process: + if self.writer is not None: + if loss_dict is not None: + for key, value in loss_dict.items(): + self.writer.add_scalar(f"{key}", value, self.step_num) + self.writer.add_scalar(f"lr", learning_rate, self.step_num) + if self.progress_bar is not None: + self.progress_bar.unpause() + + if self.accelerator.is_main_process: + # log to logger + self.logger.log({ + 'learning_rate': learning_rate, + }) + if loss_dict is not None: + for key, value in loss_dict.items(): + self.logger.log({ + f'loss/{key}': value, + }) + if self.additional_logs is not None: + for key, value in self.additional_logs.items(): + self.logger.log({ + key: value, + }) + self.additional_logs = {} + elif self.logging_config.log_every is None: + if self.accelerator.is_main_process: + # log every step + self.logger.log({ + 'learning_rate': learning_rate, + }) + for key, value in loss_dict.items(): + self.logger.log({ + f'loss/{key}': value, + }) + if self.additional_logs is not None: + for key, value in self.additional_logs.items(): + self.logger.log({ + key: value, + }) + self.additional_logs = {} + + + if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0: + if self.progress_bar is not None: + self.progress_bar.pause() + # print the timers and clear them + self.timer.print() + self.timer.reset() + if self.progress_bar is not None: + self.progress_bar.unpause() + + # commit log + if self.accelerator.is_main_process: + with self.timer('commit_logger'): + self.logger.commit(step=self.step_num) + + # sets progress bar to match out step + if self.progress_bar is not None: + self.progress_bar.update(step - self.progress_bar.n) + + ############################# + # End of step + 
def _generate_readme(self, repo_id: str) -> str:
    """Generate the content of the README.md model card.

    Reads the base model name, the optional trigger word, the model/network
    config flags and the sample images found under ``<save_root>/samples``
    and renders them into a YAML front matter block followed by markdown.

    Args:
        repo_id: Hub repository id; used for the download link and in the
            diffusers usage snippet.

    Returns:
        The full README text.
    """
    # Gather model info
    base_model = self.model_config.name_or_path
    instance_prompt = self.trigger_word if hasattr(self, "trigger_word") else None

    # Only the "other" license carries a name and a link; initialise both so
    # they are always bound.
    license_name = None
    license_link = None
    if base_model == "black-forest-labs/FLUX.1-schnell":
        license_id = "apache-2.0"
    elif base_model == "black-forest-labs/FLUX.1-dev":
        license_id = "other"
        license_name = "flux-1-dev-non-commercial-license"
        license_link = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
    else:
        license_id = "creativeml-openrail-m"

    tags = [
        "text-to-image",
    ]
    if self.model_config.is_xl:
        tags.append("stable-diffusion-xl")
    if self.model_config.is_flux:
        tags.append("flux")
    if self.model_config.is_lumina2:
        tags.append("lumina2")
    if self.model_config.is_v3:
        tags.append("sd3")
    if self.network_config:
        tags.extend(
            [
                "lora",
                "diffusers",
                "template:sd-lora",
                "ai-toolkit",
            ]
        )

    # Collect the samples of the final step, keyed by prompt index.
    # The filenames are structured as 1724085406830__00000500_0.jpg, so the
    # 2nd part is the step count and the 3rd is the index that matches the prompt.
    samples_by_index = {}
    samples_dir = os.path.join(self.save_root, "samples")
    if os.path.isdir(samples_dir):
        # sorted() makes the choice deterministic when two files share an
        # index: the lexicographically last filename wins.
        # NOTE(review): the leading number looks like a timestamp, which
        # would make that the newest file - confirm.
        for filename in sorted(os.listdir(samples_dir)):
            match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
            if not match:
                continue
            steps, index = int(match.group(1)), int(match.group(2))
            # Only the latest samples are uploaded: those whose step count
            # equals the configured number of training steps.
            if steps == self.train_config.steps:
                samples_by_index[index] = f"samples/{filename}"

    # Pair each prompt with the sample that carries the same index. Pairing
    # by index rather than by list position keeps the pairs correct when a
    # sample is missing.
    widgets = []
    for i, prompt in enumerate(self.sample_config.prompts):
        image_path = samples_by_index.get(i)
        if image_path is None:
            continue
        widgets.append(
            {
                "text": prompt,
                "output": {
                    "url": image_path
                },
            }
        )

    dtype = "torch.bfloat16" if self.model_config.is_flux else "torch.float16"

    # Prompt shown in the usage snippet. repr() yields a valid Python string
    # literal even when the prompt contains quotes, and an empty string is
    # used instead of the text "None" when nothing is available.
    if widgets:
        example_prompt = self.sample_config.prompts[0]
    else:
        example_prompt = instance_prompt or ""

    # Construct the README content
    readme_content = f"""---
tags:
{yaml.dump(tags, indent=4).strip()}
{"widget:" if widgets else ""}
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
base_model: {base_model}
{"instance_prompt: " + instance_prompt if instance_prompt else ""}
license: {license_id}
{'license_name: ' + license_name if license_id == "other" else ""}
{'license_link: ' + license_link if license_id == "other" else ""}
---

# {self.job.name}
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)


## Trigger words

{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}

## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc.

Weights for this model are available in Safetensors format.

[Download](/{repo_id}/tree/main) them in the Files & versions tab.

## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)

```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='{self.job.name}.safetensors')
image = pipeline({example_prompt!r}).images[0]
image.save("my_image.png")
```

For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)

"""
    return readme_content
class BaseTrainProcess(BaseProcess):
    """Base class for training processes.

    On construction it seeds the random number generators (when a seed is
    configured), resolves the output folder, optionally creates a
    TensorBoard writer and writes the job's raw config to
    ``<save_root>/config.yaml``.
    """

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        """Set up seeds, the output folder, TensorBoard and the saved config.

        Raises:
            ValueError: if no ``training_folder`` is configured on the
                process or on the job.
        """
        super().__init__(process_id, job, config)
        self.process_id: int
        self.config: OrderedDict
        self.writer: 'SummaryWriter'
        self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
        self.progress_bar: 'tqdm' = None

        # the process config wins; the job's value is only a fallback
        self.training_seed = self.get_conf('training_seed', getattr(self.job, 'training_seed', None))
        # if training seed is set, use it
        if self.training_seed is not None:
            torch.manual_seed(self.training_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(self.training_seed)
            random.seed(self.training_seed)

        self.progress_bar = None
        self.writer = None
        self.training_folder = self.get_conf('training_folder', getattr(self.job, 'training_folder', None))
        if self.training_folder is None:
            # os.path.join(None, ...) would raise a TypeError that does not
            # name the missing setting; fail with an actionable message.
            raise ValueError("training_folder must be set in the process config or on the job")
        self.save_root = os.path.join(self.training_folder, self.name)
        self.step = 0
        self.first_step = 0
        self.log_dir = self.get_conf('log_dir', getattr(self.job, 'log_dir', None))
        self.setup_tensorboard()
        self.save_training_config()

    def run(self):
        """Run the process; child classes override this and call ``super().run()`` first."""
        super().run()

    def print(self, *args):
        """Print a message, routing it through the progress bar when one is active."""
        if self.progress_bar is not None:
            self.progress_bar.write(' '.join(map(str, args)))
            # Redraw only. The previous update() call advanced the bar's
            # counter by one for every printed message.
            self.progress_bar.refresh()
        else:
            print(*args)

    def setup_tensorboard(self):
        """Create a timestamped TensorBoard ``SummaryWriter`` when ``log_dir`` is set."""
        if self.log_dir:
            # imported lazily so tensorboard is only needed when logging is enabled
            from torch.utils.tensorboard import SummaryWriter
            time_str = datetime.now().strftime('%Y%m%d-%H%M%S')
            summary_name = f"{self.name}_{time_str}"
            summary_dir = os.path.join(self.log_dir, summary_name)
            self.writer = SummaryWriter(summary_dir)

    def save_training_config(self):
        """Write the job's raw config to ``<save_root>/config.yaml``."""
        os.makedirs(self.save_root, exist_ok=True)
        config_path = os.path.join(self.save_root, 'config.yaml')
        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.dump(self.job.raw_config, f)
# Default extraction parameters per mode, and the type each user-supplied
# value is converted to.
mode_dict = {
    'fixed': {
        'linear': 4,
        'conv': 0,
        'type': int
    },
    'threshold': {
        'linear': 0,
        'conv': 0,
        'type': float
    },
    'ratio': {
        'linear': 0.5,
        'conv': 0,
        'type': float
    },
    'quantile': {
        'linear': 0.5,
        'conv': 0,
        'type': float
    }
}

CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6


class ExtractLoraProcess(BaseExtractProcess):
    """Extract a LoRA from the difference between the job's base and extract models."""

    def __init__(self, process_id: int, job, config: OrderedDict):
        """Read the extraction mode and its parameters from the config.

        Raises:
            ValueError: if ``mode`` is not one of the keys of ``mode_dict``.
        """
        super().__init__(process_id, job, config)
        self.mode = self.get_conf('mode', 'fixed')

        # set modes
        if self.mode not in mode_dict:
            raise ValueError(f"Unknown mode: {self.mode}")
        defaults = mode_dict[self.mode]
        self.linear_param = self.get_conf('linear', defaults['linear'], as_type=defaults['type'])
        # `linear` is kept as an alias because the original set both attributes
        self.linear = self.linear_param
        self.conv_param = self.get_conf('conv', defaults['conv'], as_type=defaults['type'])
        self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
        self.sparsity = self.get_conf('sparsity', 0.98)

    def run(self):
        """Extract the difference, attach its metadata and save the state dict."""
        super().run()
        # The original printed self.dim, which this class never assigns; the
        # linear parameter is the value it was reporting.
        print(f"Running process: {self.mode}, dim: {self.linear_param}")

        # NOTE(review): the original passed `conv_param > 1e-10`, which turned
        # linear-only on exactly when a conv value was requested. Inverted
        # here on the assumption that `linear_only` makes extract_diff skip
        # conv layers - confirm against toolkit.lycoris_utils.extract_diff.
        linear_only = not (self.conv_param > 0.0000000001)

        state_dict, extract_diff_meta = extract_diff(
            self.job.model_base,
            self.job.model_extract,
            self.mode,
            self.linear_param,
            self.conv_param,
            self.job.device,
            self.use_sparse_bias,
            self.sparsity,
            small_conv=False,
            linear_only=linear_only,
            extract_unet=self.extract_unet,
            extract_text_encoder=self.extract_text_encoder
        )

        self.add_meta(extract_diff_meta)
        self.save(state_dict)

    def get_output_path(self, prefix=None, suffix=None):
        """Return the output path, suffixed with the linear parameter by default."""
        if suffix is None:
            # was self.dim, which is not assigned anywhere in this class
            suffix = f"_{self.linear_param}"
        return super().get_output_path(prefix, suffix)
class GenerateProcess(BaseProcess):
    """Load a model and generate one image per configured prompt and repeat."""
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None
    sd: StableDiffusion

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        """Parse the config and construct the (not yet loaded) model wrapper."""
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.device = self.get_conf('device', self.job.device)
        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
        self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))

        self.progress_bar = None

        ModelClass = get_model_class(self.model_config)
        # if the model class has get_train_scheduler static method
        if hasattr(ModelClass, 'get_train_scheduler'):
            sampler = ModelClass.get_train_scheduler()
        else:
            # get the noise scheduler; later flags override earlier ones
            arch = 'sd'
            if self.model_config.is_pixart:
                arch = 'pixart'
            if self.model_config.is_flux:
                arch = 'flux'
            if self.model_config.is_lumina2:
                arch = 'lumina2'
            # NOTE(review): the original read self.train_config.noise_scheduler,
            # but nothing in this class assigns train_config. The configured
            # generate sampler name is used instead - confirm that BaseProcess
            # does not provide train_config.
            sampler = get_sampler(
                self.generate_config.sampler,
                {
                    "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
                },
                arch=arch,
            )
        self.sd = ModelClass(
            device=self.device,
            model_config=self.model_config,
            dtype=self.model_config.dtype,
            noise_scheduler=sampler,
        )

        print(f"Using device {self.device}")

    def clean_prompt(self, prompt: str):
        """Drop every character that is not alphanumeric, a space, a comma or a quote."""
        # remove any non alpha numeric characters or ,'" from prompt
        return ''.join(e for e in prompt if e.isalnum() or e in ", '\"")

    def run(self):
        """Load the model, build one image config per prompt and repeat, and generate."""
        with torch.no_grad():
            super().run()
            print("Loading model...")
            self.sd.load_model()
            self.sd.pipeline.to(self.device, self.torch_dtype)

            if self.generate_config.compile:
                # only announce compilation when it actually happens
                print("Compiling model...")
                self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")

            # every prompt is rendered num_repeats times
            total_images = len(self.generate_config.prompts) * self.generate_config.num_repeats
            print(f"Generating {total_images} images")
            # build prompt image configs
            prompt_image_configs = []
            for _ in range(self.generate_config.num_repeats):
                for prompt in self.generate_config.prompts:
                    # remove --
                    prompt = prompt.replace('--', '').strip()
                    width = self.generate_config.width
                    height = self.generate_config.height

                    if self.generate_config.size_list is not None:
                        # randomly select a size
                        width, height = random.choice(self.generate_config.size_list)

                    prompt_image_configs.append(GenerateImageConfig(
                        prompt=prompt,
                        prompt_2=self.generate_config.prompt_2,
                        width=width,
                        height=height,
                        num_inference_steps=self.generate_config.sample_steps,
                        guidance_scale=self.generate_config.guidance_scale,
                        negative_prompt=self.generate_config.neg,
                        negative_prompt_2=self.generate_config.neg_2,
                        seed=self.generate_config.seed,
                        guidance_rescale=self.generate_config.guidance_rescale,
                        output_ext=self.generate_config.ext,
                        output_folder=self.output_folder,
                        add_prompt_file=self.generate_config.prompt_file
                    ))
            # generate images
            self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)

            print("Done generating images")
            # cleanup
            del self.sd
            gc.collect()
            torch.cuda.empty_cache()
class ModRescaleLoraProcess(BaseProcess):
    """Rescale a saved LoRA's tensors by ``current_weight / target_weight``.

    The factor is applied either to the ``.alpha`` tensors
    (``scale_target == 'alpha'``) or, as its square root, to both the
    ``.lora_up.weight`` and ``.lora_down.weight`` tensors
    (``scale_target == 'up_down'``).
    """
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        """Read input/output paths, weights and the scale target from the config."""
        super().__init__(process_id, job, config)
        self.process_id: int
        self.config: OrderedDict
        self.progress_bar: ForwardRef('tqdm') = None
        self.input_path = self.get_conf('input_path', required=True)
        self.output_path = self.get_conf('output_path', required=True)
        self.replace_meta = self.get_conf('replace_meta', default=False)
        self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype)
        self.current_weight = self.get_conf('current_weight', required=True, as_type=float)
        self.target_weight = self.get_conf('target_weight', required=True, as_type=float)
        self.scale_target = self.get_conf('scale_target', default='up_down')  # alpha or up_down
        self.is_xl = self.get_conf('is_xl', default=False, as_type=bool)
        self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool)

        self.progress_bar = None

    def run(self):
        """Load the LoRA, rescale the selected tensors and save the result."""
        super().run()
        source_state_dict = load_file(self.input_path)
        source_meta = load_metadata_from_safetensors(self.input_path)

        if self.replace_meta:
            self.meta.update(
                add_base_model_info_to_meta(
                    self.meta,
                    is_xl=self.is_xl,
                    is_v2=self.is_v2,
                )
            )
            save_meta = get_meta_for_safetensors(self.meta, self.job.name)
        else:
            save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False)

        # A bare filename has an empty dirname, and os.makedirs('') raises.
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # all loras have an alpha, up weight and down weight
        # - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha",
        # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
        # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight",
        # we can rescale by adjusting the alpha or the up weights, or the up and down weights
        # some locons also have mid weights, we will leave those alone for now
        #
        # when adjusting alpha, it is used to calculate the multiplier in a lora module
        # - scale = alpha / lora_dim
        # - output = layer_out + lora_up_out * multiplier * scale
        #
        # Both factors are the same for every key, so build them once.
        fp32 = get_torch_dtype('fp32')
        total_module_scale = torch.tensor(self.current_weight / self.target_weight) \
            .to("cpu", dtype=fp32)
        num_modules_layers = 2  # up and down
        # the factor is split evenly across the up and down matrices
        up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
            .to("cpu", dtype=fp32)

        new_state_dict = OrderedDict()

        for key in list(source_state_dict.keys()):
            v = source_state_dict[key]
            v = v.detach().clone().to("cpu").to(fp32)

            # only update alpha
            if self.scale_target == 'alpha' and key.endswith('.alpha'):
                v = v * total_module_scale
            # The original condition was `A and B or C`, which parses as
            # `(A and B) or C` and therefore scaled the down weights for every
            # scale_target. The tuple form applies the mode check to both.
            if self.scale_target == 'up_down' and key.endswith(('.lora_up.weight', '.lora_down.weight')):
                v = v * up_down_scale
            v = v.detach().clone().to("cpu").to(self.save_dtype)
            new_state_dict[key] = v

        save_meta = add_model_hash_to_meta(new_state_dict, save_meta)
        save_file(new_state_dict, self.output_path, save_meta)

        # cleanup incase there are other jobs
        del new_state_dict
        del source_state_dict
        del source_meta

        torch.cuda.empty_cache()
        gc.collect()

        print(f"Saved to {self.output_path}")
class TrainESRGANProcess(BaseTrainProcess):
    """Training process for an ESRGAN (RRDB) upscaler.

    The generator is optimised with a weighted sum of MSE, total-variation,
    pattern, VGG19 style/content and (optionally) critic losses. All weights
    and training options are read from the process config in ``__init__``.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.data_loader = None
        self.model: ESRGAN = None
        self.device = self.get_conf('device', self.job.device)
        # NOTE(review): the default is the *string* 'None', so the
        # `self.pretrained_path is None` branch in load_model() is never taken
        # with the default config - confirm whether None was intended.
        self.pretrained_path = self.get_conf('pretrained_path', 'None')
        self.datasets_objects = self.get_conf('datasets', required=True)
        self.batch_size = self.get_conf('batch_size', 1, as_type=int)
        self.resolution = self.get_conf('resolution', 256, as_type=int)
        self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
        self.sample_every = self.get_conf('sample_every', None)
        self.optimizer_type = self.get_conf('optimizer', 'adam')
        self.epochs = self.get_conf('epochs', None, as_type=int)
        self.max_steps = self.get_conf('max_steps', None, as_type=int)
        self.save_every = self.get_conf('save_every', None)
        self.upscale_sample = self.get_conf('upscale_sample', 4)
        self.dtype = self.get_conf('dtype', 'float32')
        self.sample_sources = self.get_conf('sample_sources', None)
        self.log_every = self.get_conf('log_every', 100, as_type=int)
        self.style_weight = self.get_conf('style_weight', 0, as_type=float)
        self.content_weight = self.get_conf('content_weight', 0, as_type=float)
        self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
        self.zoom = self.get_conf('zoom', 4, as_type=int)
        self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
        self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
        self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
        self.optimizer_params = self.get_conf('optimizer_params', {})
        self.augmentations = self.get_conf('augmentations', {})
        self.torch_dtype = get_torch_dtype(self.dtype)
        # the generator is always run in float32, whatever `dtype` is configured
        # (the original had two identical branches here)
        self.esrgan_dtype = torch.float32

        self.vgg_19 = None
        self.style_weight_scalers = []
        self.content_weight_scalers = []

        # throw error if zoom is not divisible by 2
        if self.zoom % 2 != 0:
            raise ValueError('zoom must be divisible by 2')

        self.step_num = 0
        self.epoch_num = 0

        self.use_critic = self.get_conf('use_critic', False, as_type=bool)
        self.critic = None

        if self.use_critic:
            self.critic = Critic(
                device=self.device,
                dtype=self.dtype,
                process=self,
                **self.get_conf('critic', {})  # pass any other params
            )

        if self.sample_every is not None and self.sample_sources is None:
            raise ValueError('sample_every is specified but sample_sources is not')

        if self.epochs is None and self.max_steps is None:
            raise ValueError('epochs or max_steps must be specified')

        self.data_loaders = []
        # check datasets
        assert isinstance(self.datasets_objects, list)
        for dataset in self.datasets_objects:
            if 'path' not in dataset:
                raise ValueError('dataset must have a path')
            # check if is dir
            if not os.path.isdir(dataset['path']):
                raise ValueError(f"dataset path does is not a directory: {dataset['path']}")

        # make training folder
        if not os.path.exists(self.save_root):
            os.makedirs(self.save_root, exist_ok=True)

        self._pattern_loss = None

    def update_training_metadata(self):
        """Store the current step/epoch counters in the process metadata."""
        self.add_meta(OrderedDict({"training_info": self.get_training_info()}))

    def get_training_info(self):
        """Return an ordered dict with the current ``step`` and ``epoch``."""
        info = OrderedDict({
            'step': self.step_num,
            'epoch': self.epoch_num,
        })
        return info

    def load_datasets(self):
        """Build ``self.data_loader`` from the configured datasets (once).

        Every dataset gets a leading ``Resize`` augmentation that shrinks the
        image to ``resolution // zoom``; dataset-level augmentations take
        precedence over the process-level ``augmentations``.
        """
        if self.data_loader is None:
            print(f"Loading datasets")
            datasets = []
            for dataset in self.datasets_objects:
                print(f" - Dataset: {dataset['path']}")
                ds = copy.copy(dataset)
                ds['resolution'] = self.resolution

                if 'augmentations' not in ds:
                    ds['augmentations'] = self.augmentations

                # add the resize down augmentation
                # The process-level default is an empty dict, and `list + dict`
                # raises TypeError, so normalise empty/None values to a list.
                ds['augmentations'] = [{
                    'method': 'Resize',
                    'params': {
                        'width': int(self.resolution // self.zoom),
                        'height': int(self.resolution // self.zoom),
                        # downscale interpolation, string will be evaluated
                        'interpolation': 'cv2.INTER_AREA'
                    }
                }] + list(ds['augmentations'] or [])

                image_dataset = AugmentedImageDataset(ds)
                datasets.append(image_dataset)

            concatenated_dataset = ConcatDataset(datasets)
            self.data_loader = DataLoader(
                concatenated_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=6
            )

    def setup_vgg19(self):
        """Create the frozen VGG19 feature network and per-layer loss scalers (once)."""
        if self.vgg_19 is None:
            self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
                single_target=True,
                device=self.device,
                output_layer_name='pool_4',
                dtype=self.torch_dtype
            )
            self.vgg_19.to(self.device, dtype=self.torch_dtype)
            self.vgg_19.requires_grad_(False)

            # we run random noise through first to get layer scalers to normalize the loss per layer
            # bs of 2 because we run pred and target through stacked
            noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype)
            self.vgg_19(noise)
            for style_loss in self.style_losses:
                # get a scaler to normalize to 1
                scaler = 1 / torch.mean(style_loss.loss).item()
                self.style_weight_scalers.append(scaler)
            for content_loss in self.content_losses:
                # get a scaler to normalize to 1
                scaler = 1 / torch.mean(content_loss.loss).item()
                # if is nan, set to 1
                if scaler != scaler:
                    scaler = 1
                    print(f"Warning: content loss scaler is nan, setting to 1")
                self.content_weight_scalers.append(scaler)

            self.print(f"Style weight scalers: {self.style_weight_scalers}")
            self.print(f"Content weight scalers: {self.content_weight_scalers}")

    def get_style_loss(self):
        """Sum of the scaled VGG style losses, or a zero tensor when style_weight is 0."""
        if self.style_weight > 0:
            # scale all losses with loss scalers
            loss = torch.sum(
                torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)]))
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_content_loss(self):
        """Sum of the scaled VGG content losses, or a zero tensor when content_weight is 0."""
        if self.content_weight > 0:
            # scale all losses with loss scalers
            loss = torch.sum(torch.stack(
                [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)]))
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_mse_loss(self, pred, target):
        """MSE between ``pred`` and ``target``, or a zero tensor when mse_weight is 0."""
        if self.mse_weight > 0:
            loss_fn = nn.MSELoss()
            loss = loss_fn(pred, target)
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_tv_loss(self, pred, target):
        """Comparative total-variation loss, or a zero tensor when tv_weight is 0."""
        if self.tv_weight > 0:
            get_tv_loss = ComparativeTotalVariation()
            loss = get_tv_loss(pred, target)
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_pattern_loss(self, pred, target):
        """Mean pattern loss; the ``PatternLoss`` module is created lazily on first use."""
        if self._pattern_loss is None:
            self._pattern_loss = PatternLoss(
                pattern_size=self.zoom,
                dtype=self.torch_dtype
            ).to(self.device, dtype=self.torch_dtype)
        self._pattern_loss = self._pattern_loss.to(self.device, dtype=self.torch_dtype)
        loss = torch.mean(self._pattern_loss(pred, target))
        return loss

    def save(self, step=None):
        """Save the generator weights as a float32 ``.pth`` file in ``save_root``.

        ``step`` (if given) is appended to the file name, zero-padded to 9
        digits. The critic is saved as well when it is in use.
        """
        if not os.path.exists(self.save_root):
            os.makedirs(self.save_root, exist_ok=True)

        step_num = ''
        if step is not None:
            # zeropad 9 digits
            step_num = f"_{str(step).zfill(9)}"

        self.update_training_metadata()
        # filename = f'{self.job.name}{step_num}.safetensors'
        filename = f'{self.job.name}{step_num}.pth'
        # prepare meta (only needed by the disabled safetensors path below)
        save_meta = get_meta_for_safetensors(self.meta, self.job.name)

        # state_dict = self.model.state_dict()

        # state has the original state dict keys so we can save what we started from
        save_state_dict = self.model.state_dict()

        for key in list(save_state_dict.keys()):
            v = save_state_dict[key]
            v = v.detach().clone().to("cpu").to(torch.float32)
            save_state_dict[key] = v

        # most things wont use safetensors, save as torch
        # save_file(save_state_dict, os.path.join(self.save_root, filename), save_meta)
        torch.save(save_state_dict, os.path.join(self.save_root, filename))

        self.print(f"Saved to {os.path.join(self.save_root, filename)}")

        if self.use_critic:
            self.critic.save(step)

    def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None):
        """Write side-by-side sample images (input | model output | target).

        One image is written per entry in ``sample_sources`` into
        ``<save_root>/samples``. When ``batch`` is given as ``[targets, inputs]``,
        one image per batch element is written into ``<save_root>/samples_batch``.
        The model is put in eval mode for sampling and back in train mode afterwards.
        """
        sample_folder = os.path.join(self.save_root, 'samples')
        if not os.path.exists(sample_folder):
            os.makedirs(sample_folder, exist_ok=True)
        batch_sample_folder = os.path.join(self.save_root, 'samples_batch')

        batch_targets = None
        batch_inputs = None
        if batch is not None and not os.path.exists(batch_sample_folder):
            os.makedirs(batch_sample_folder, exist_ok=True)

        self.model.eval()

        def process_and_save(img, target_img, save_path):
            img = img.to(self.device, dtype=self.esrgan_dtype)
            output = self.model(img)
            # output = (output / 2 + 0.5).clamp(0, 1)
            output = output.clamp(0, 1)
            img = img.clamp(0, 1)
            # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
            output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
            img = img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()

            # convert to pillow image
            output = Image.fromarray((output * 255).astype(np.uint8))
            img = Image.fromarray((img * 255).astype(np.uint8))

            if isinstance(target_img, torch.Tensor):
                # convert to pil
                target_img = target_img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
                target_img = Image.fromarray((target_img * 255).astype(np.uint8))

            # upscale to size * self.upscale_sample while maintaining pixels
            output = output.resize(
                (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
                resample=Image.NEAREST
            )
            img = img.resize(
                (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
                resample=Image.NEAREST
            )

            width, height = output.size

            # stack input image and decoded image
            target_image = target_img.resize((width, height))
            output = output.resize((width, height))
            img = img.resize((width, height))

            output_img = Image.new('RGB', (width * 3, height))

            output_img.paste(img, (0, 0))
            output_img.paste(output, (width, 0))
            output_img.paste(target_image, (width * 2, 0))

            output_img.save(save_path)

        # step suffix shared by both loops; defined up front so the batch loop
        # never sees an unbound name when there are no sample sources
        step_num = ''
        if step is not None:
            # zero-pad 9 digits
            step_num = f"_{str(step).zfill(9)}"

        with torch.no_grad():
            # run() samples a baseline unconditionally, so sample_sources may
            # legitimately be None here
            for i, img_url in enumerate(self.sample_sources or []):
                img = exif_transpose(Image.open(img_url))
                img = img.convert('RGB')
                # crop if not square
                if img.width != img.height:
                    min_dim = min(img.width, img.height)
                    img = img.crop((0, 0, min_dim, min_dim))
                # resize
                img = img.resize((self.resolution * self.zoom, self.resolution * self.zoom), resample=Image.BICUBIC)

                target_image = img
                # downscale the image input
                img = img.resize((self.resolution, self.resolution), resample=Image.BICUBIC)

                img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)

                seconds_since_epoch = int(time.time())
                # zero-pad 2 digits
                i_str = str(i).zfill(2)
                filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
                process_and_save(img, target_image, os.path.join(sample_folder, filename))

            if batch is not None:
                batch_targets = batch[0].detach()
                batch_inputs = batch[1].detach()
                batch_targets = torch.chunk(batch_targets, batch_targets.shape[0], dim=0)
                batch_inputs = torch.chunk(batch_inputs, batch_inputs.shape[0], dim=0)

                for i in range(len(batch_inputs)):
                    seconds_since_epoch = int(time.time())
                    # zero-pad 2 digits
                    i_str = str(i).zfill(2)
                    filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
                    process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))

        self.model.train()

    def load_model(self):
        """Create ``self.model``, resuming from the newest checkpoint if one exists.

        Checkpoints named ``<job name>*.safetensors`` / ``*.pth`` in
        ``save_root`` take precedence over ``pretrained_path``.

        Raises:
            Exception: the selected path has an unsupported file extension.
        """
        state_dict = None
        path_to_load = self.pretrained_path
        # see if we have a checkpoint in our output to resume from
        self.print(f"Looking for latest checkpoint in {self.save_root}")
        files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors"))
        files += glob.glob(os.path.join(self.save_root, f"{self.job.name}*.pth"))
        if files and len(files) > 0:
            latest_file = max(files, key=os.path.getmtime)
            print(f" - Latest checkpoint is: {latest_file}")
            path_to_load = latest_file
            # todo update step and epoch count
        elif self.pretrained_path is None:
            self.print(f" - No checkpoint found, starting from scratch")
        else:
            self.print(f" - No checkpoint found, loading pretrained model")
            self.print(f" - path: {path_to_load}")

        if path_to_load is not None:
            self.print(f" - Loading pretrained checkpoint: {path_to_load}")
            # if ends with pth then assume pytorch checkpoint
            if path_to_load.endswith('.pth') or path_to_load.endswith('.pt'):
                state_dict = torch.load(path_to_load, map_location=self.device)
            elif path_to_load.endswith('.safetensors'):
                state_dict_raw = load_file(path_to_load)
                # make ordered dict as most things need it
                state_dict = OrderedDict()
                for key in esrgan_safetensors_keys:
                    state_dict[key] = state_dict_raw[key]
            else:
                raise Exception(f"Unknown file extension for checkpoint: {path_to_load}")

        # todo determine architecture from checkpoint
        self.model = ESRGAN(
            state_dict
        ).to(self.device, dtype=self.esrgan_dtype)

        # set the model to training mode
        self.model.train()
        self.model.requires_grad_(True)

    def run(self):
        """Run the full training loop, then save the final model.

        The effective epoch/step budget is derived from whichever of
        ``epochs`` / ``max_steps`` is configured (both may be set; the tighter
        one wins). Samples, checkpoints and tensorboard scalars are emitted
        every ``sample_every`` / ``save_every`` / ``log_every`` steps.
        """
        super().run()
        self.load_datasets()
        # The critic only exists when use_critic is set; without it every
        # batch is a generator step.
        steps_per_step = (self.critic.num_critic_per_gen + 1) if self.use_critic else 1
        # guard against a loader shorter than steps_per_step (division by zero)
        gen_steps_per_epoch = max(1, len(self.data_loader) // steps_per_step)

        num_epochs = self.epochs
        # max_steps is optional (__init__ only requires one of epochs/max_steps)
        if self.max_steps is not None:
            max_step_epochs = self.max_steps // gen_steps_per_epoch
            if num_epochs is None or num_epochs > max_step_epochs:
                num_epochs = max_step_epochs

        max_epoch_steps = len(self.data_loader) * num_epochs * steps_per_step
        num_steps = self.max_steps
        if num_steps is None or num_steps > max_epoch_steps:
            num_steps = max_epoch_steps
        self.max_steps = num_steps
        self.epochs = num_epochs
        start_step = self.step_num
        self.first_step = start_step

        self.print(f"Training ESRGAN model:")
        self.print(f" - Training folder: {self.training_folder}")
        self.print(f" - Batch size: {self.batch_size}")
        self.print(f" - Learning rate: {self.learning_rate}")
        self.print(f" - Epochs: {num_epochs}")
        self.print(f" - Max steps: {self.max_steps}")

        # load model
        self.load_model()

        params = self.model.parameters()

        if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
            self.setup_vgg19()
            self.vgg_19.requires_grad_(False)
            self.vgg_19.eval()
            if self.use_critic:
                self.critic.setup()

        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                  optimizer_params=self.optimizer_params)

        # setup scheduler
        # todo allow other schedulers
        scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer,
            total_iters=num_steps,
            factor=1,
            verbose=False
        )

        # setup tqdm progress bar
        self.progress_bar = tqdm(
            total=num_steps,
            desc='Training ESRGAN',
            leave=True
        )

        blank_losses = OrderedDict({
            "total": [],
            "style": [],
            "content": [],
            "mse": [],
            "kl": [],
            "tv": [],
            "ptn": [],
            "crD": [],
            "crG": [],
        })
        epoch_losses = copy.deepcopy(blank_losses)
        log_losses = copy.deepcopy(blank_losses)
        print("Generating baseline samples")
        self.sample(step=0)
        # range start at self.epoch_num go to self.epochs
        critic_losses = []
        for epoch in range(self.epoch_num, self.epochs, 1):
            if self.step_num >= self.max_steps:
                break
            flush()
            for targets, inputs in self.data_loader:
                if self.step_num >= self.max_steps:
                    break
                with torch.no_grad():
                    # randomly pick critic-only steps so that, on average,
                    # num_critic_per_gen critic steps run per generator step
                    is_critic_only_step = False
                    if self.use_critic and 1 / (self.critic.num_critic_per_gen + 1) < np.random.uniform():
                        is_critic_only_step = True

                    targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
                    inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()

                optimizer.zero_grad()
                # dont do grads here for critic step
                do_grad = not is_critic_only_step
                with torch.set_grad_enabled(do_grad):
                    pred = self.model(inputs)

                    pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
                    targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
                    if torch.isnan(pred).any():
                        raise ValueError('pred has nan values')
                    if torch.isnan(targets).any():
                        raise ValueError('targets has nan values')

                    # Run through VGG19
                    if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
                        stacked = torch.cat([pred, targets], dim=0)
                        # stacked = (stacked / 2 + 0.5).clamp(0, 1)
                        stacked = stacked.clamp(0, 1)
                        self.vgg_19(stacked)
                        # make sure we dont have nans
                        if torch.isnan(self.vgg19_pool_4.tensor).any():
                            raise ValueError('vgg19_pool_4 has nan values')

                if is_critic_only_step:
                    critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
                    critic_losses.append(critic_d_loss)
                    # don't do generator step
                    continue
                else:
                    # doing a regular step
                    if len(critic_losses) == 0:
                        critic_d_loss = 0
                    else:
                        critic_d_loss = sum(critic_losses) / len(critic_losses)

                style_loss = self.get_style_loss() * self.style_weight
                content_loss = self.get_content_loss() * self.content_weight

                mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight
                tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight
                pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight
                if self.use_critic:
                    critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
                else:
                    critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

                loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
                # make sure non nan
                if torch.isnan(loss):
                    raise ValueError('loss is nan')

                # Backward pass and optimization
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                # update progress bar
                loss_value = loss.item()
                # get exponent like 3.54e-4
                loss_string = f"loss: {loss_value:.2e}"
                if self.content_weight > 0:
                    loss_string += f" cnt: {content_loss.item():.2e}"
                if self.style_weight > 0:
                    loss_string += f" sty: {style_loss.item():.2e}"
                if self.mse_weight > 0:
                    loss_string += f" mse: {mse_loss.item():.2e}"
                if self.tv_weight > 0:
                    loss_string += f" tv: {tv_loss.item():.2e}"
                if self.pattern_weight > 0:
                    loss_string += f" ptn: {pattern_loss.item():.2e}"
                if self.use_critic and self.critic_weight > 0:
                    loss_string += f" crG: {critic_gen_loss.item():.2e}"
                if self.use_critic:
                    loss_string += f" crD: {critic_d_loss:.2e}"

                if self.optimizer_type.startswith('dadaptation') or self.optimizer_type.startswith('prodigy'):
                    learning_rate = (
                            optimizer.param_groups[0]["d"] *
                            optimizer.param_groups[0]["lr"]
                    )
                else:
                    learning_rate = optimizer.param_groups[0]['lr']

                lr_critic_string = ''
                if self.use_critic:
                    lr_critic = self.critic.get_lr()
                    lr_critic_string = f" lrC: {lr_critic:.1e}"

                self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
                self.progress_bar.set_description(f"E: {epoch}")
                self.progress_bar.update(1)

                epoch_losses["total"].append(loss_value)
                epoch_losses["style"].append(style_loss.item())
                epoch_losses["content"].append(content_loss.item())
                epoch_losses["mse"].append(mse_loss.item())
                epoch_losses["tv"].append(tv_loss.item())
                epoch_losses["ptn"].append(pattern_loss.item())
                epoch_losses["crG"].append(critic_gen_loss.item())
                epoch_losses["crD"].append(critic_d_loss)

                log_losses["total"].append(loss_value)
                log_losses["style"].append(style_loss.item())
                log_losses["content"].append(content_loss.item())
                log_losses["mse"].append(mse_loss.item())
                log_losses["tv"].append(tv_loss.item())
                log_losses["ptn"].append(pattern_loss.item())
                log_losses["crG"].append(critic_gen_loss.item())
                log_losses["crD"].append(critic_d_loss)

                # don't do on first step
                if self.step_num != start_step:
                    if self.sample_every and self.step_num % self.sample_every == 0:
                        # print above the progress bar
                        self.print(f"Sampling at step {self.step_num}")
                        self.sample(self.step_num, batch=[targets, inputs])

                    if self.save_every and self.step_num % self.save_every == 0:
                        # print above the progress bar
                        self.print(f"Saving at step {self.step_num}")
                        self.save(self.step_num)

                    if self.log_every and self.step_num % self.log_every == 0:
                        # log to tensorboard
                        if self.writer is not None:
                            # get avg loss
                            for key in log_losses:
                                log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6)
                                # if log_losses[key] > 0:
                                self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
                        # reset log losses
                        log_losses = copy.deepcopy(blank_losses)

                self.step_num += 1
            # end epoch
            if self.writer is not None:
                eps = 1e-6
                # get avg loss over the epoch (the original averaged log_losses
                # here, which only holds the values since the last log reset)
                for key in epoch_losses:
                    epoch_losses[key] = sum(epoch_losses[key]) / (len(epoch_losses[key]) + eps)
                    if epoch_losses[key] > 0:
                        self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
            # reset epoch losses
            epoch_losses = copy.deepcopy(blank_losses)

        self.save()


# --- jobs/process/TrainFineTuneProcess.py (separate file in the diff) ---
class TrainFineTuneProcess(BaseTrainProcess):
    """Placeholder fine-tune process; ``run`` is a no-op meant to be overridden."""

    def __init__(self, process_id: int, job: TrainJob, config: OrderedDict):
        super().__init__(process_id, job, config)

    def run(self):
        # implement in child class
        # be sure to call super().run() first
        pass
def flush():
    """Release cached CUDA memory and run the Python garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()


class RescaleConfig:
    """Settings for resolution rescaling, read from the ``rescale`` config section.

    ``to_resolution`` defaults to ``int(from_resolution * scale)`` when not given.
    """

    def __init__(
            self,
            **kwargs
    ):
        self.from_resolution = kwargs.get('from_resolution', 512)
        self.scale = kwargs.get('scale', 0.5)
        self.latent_tensor_dir = kwargs.get('latent_tensor_dir', None)
        self.num_latent_tensors = kwargs.get('num_latent_tensors', 1000)
        self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
        self.prompt_dropout = kwargs.get('prompt_dropout', 0.1)


class PromptEmbedsCache:
    """Dictionary-like cache of prompt embeddings keyed by prompt text.

    Looking up a missing prompt returns ``None`` instead of raising.
    """
    prompts: dict[str, PromptEmbeds]

    def __init__(self) -> None:
        # Per-instance storage: the original kept the dict as a class attribute,
        # so every cache instance silently shared (and overwrote) the same entries.
        self.prompts = {}

    def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
        return self.prompts.get(__name)


class TrainSDRescaleProcess(BaseSDTrainProcess):
    """Train a Stable Diffusion UNet to work at a reduced latent resolution.

    A parent copy of the model predicts noise at ``from_resolution``; those
    predictions are cached on disk and the trained UNet is fitted to the
    size-reduced denoising result.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        # pass our custom pipeline to super so it sets it up
        super().__init__(process_id, job, config)
        self.step_num = 0
        self.start_step = 0
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True))
        self.reduce_size_fn = ReductionKernel(
            in_channels=4,
            kernel_size=int(self.rescale_config.from_resolution // self.rescale_config.to_resolution),
            dtype=get_torch_dtype(self.train_config.dtype),
            device=self.device_torch,
        )

        self.latent_paths: List[str] = []
        self.empty_embedding: PromptEmbeds = None

    def before_model_load(self):
        pass

    def get_latent_tensors(self):
        """Make sure ``num_latent_tensors`` cached target files exist on disk.

        Existing ``*.safetensors`` files in ``latent_tensor_dir`` are reused;
        missing ones are generated with a temporary parent model. The global
        torch RNG state and default device are restored afterwards.

        Raises:
            ValueError: ``rescale.latent_tensor_dir`` is not configured.
        """
        latent_dir = self.rescale_config.latent_tensor_dir
        if latent_dir is None:
            # os.path.exists(None) would otherwise fail with an opaque TypeError
            raise ValueError('rescale.latent_tensor_dir must be specified')

        dtype = get_torch_dtype(self.train_config.dtype)

        num_to_generate = 0
        # check if dir exists
        if not os.path.exists(latent_dir):
            os.makedirs(latent_dir)
            num_to_generate = self.rescale_config.num_latent_tensors
        else:
            # find existing
            current_tensor_list = glob.glob(os.path.join(latent_dir, "*.safetensors"))
            num_to_generate = self.rescale_config.num_latent_tensors - len(current_tensor_list)
            self.latent_paths = current_tensor_list

        if num_to_generate > 0:
            print(f"Generating {num_to_generate}/{self.rescale_config.num_latent_tensors} latent tensors")

            # unload other model
            self.sd.unet.to('cpu')

            # load aux network
            self.sd_parent = StableDiffusion(
                self.device_torch,
                model_config=self.model_config,
                dtype=self.train_config.dtype,
            )
            self.sd_parent.load_model()
            self.sd_parent.unet.to(self.device_torch, dtype=dtype)
            # we dont need text encoder for this

            del self.sd_parent.text_encoder
            del self.sd_parent.tokenizer

            self.sd_parent.unet.eval()
            self.sd_parent.unet.requires_grad_(False)

            # save current seed state for training
            rng_state = torch.get_rng_state()
            cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

            text_embeddings = train_tools.concat_prompt_embeddings(
                self.empty_embedding,  # unconditional (negative prompt)
                self.empty_embedding,  # conditional (positive prompt)
                self.train_config.batch_size,
            )

            # Use the device as a context manager rather than
            # torch.set_default_device(): the original changed the process-wide
            # default device and never restored it.
            with self.device_torch:
                for i in tqdm(range(num_to_generate)):
                    # get a random seed
                    seed = torch.randint(0, 2 ** 32, (1,)).item()
                    # zero pad seed string to max length
                    seed_string = str(seed).zfill(10)
                    # set seed
                    torch.manual_seed(seed)
                    if torch.cuda.is_available():
                        torch.cuda.manual_seed(seed)

                    # number of denoising steps the scheduler is configured for
                    timesteps_to = self.train_config.max_denoising_steps

                    # set the scheduler to the number of steps
                    self.sd.noise_scheduler.set_timesteps(
                        timesteps_to, device=self.device_torch
                    )

                    noise = self.sd.get_latent_noise(
                        pixel_height=self.rescale_config.from_resolution,
                        pixel_width=self.rescale_config.from_resolution,
                        batch_size=self.train_config.batch_size,
                        noise_offset=self.train_config.noise_offset,
                    ).to(self.device_torch, dtype=dtype)

                    # get latents
                    latents = noise * self.sd.noise_scheduler.init_noise_sigma
                    latents = latents.to(self.device_torch, dtype=dtype)

                    # get random guidance scale from 1.0 to 10.0 (CFG)
                    guidance_scale = torch.rand(1).item() * 9.0 + 1.0

                    # do a timestep of 1
                    timestep = 1

                    noise_pred_target = self.sd_parent.predict_noise(
                        latents,
                        text_embeddings=text_embeddings,
                        timestep=timestep,
                        guidance_scale=guidance_scale
                    )

                    # build state dict
                    state_dict = OrderedDict()
                    state_dict['noise_pred_target'] = noise_pred_target.to('cpu', dtype=torch.float16)
                    state_dict['latents'] = latents.to('cpu', dtype=torch.float16)
                    state_dict['guidance_scale'] = torch.tensor(guidance_scale).to('cpu', dtype=torch.float16)
                    state_dict['timestep'] = torch.tensor(timestep).to('cpu', dtype=torch.float16)
                    state_dict['timesteps_to'] = torch.tensor(timesteps_to).to('cpu', dtype=torch.float16)
                    # NOTE(review): float32 avoids fp16 overflow but cannot represent
                    # every 32-bit seed exactly; the value is not read back in
                    # hook_train_loop, so it is informational only.
                    state_dict['seed'] = torch.tensor(seed).to('cpu', dtype=torch.float32)

                    file_name = f"{seed_string}_{i}.safetensors"
                    file_path = os.path.join(latent_dir, file_name)
                    save_file(state_dict, file_path)
                    self.latent_paths.append(file_path)

            print("Removing parent model")
            # delete parent
            del self.sd_parent
            flush()

            torch.set_rng_state(rng_state)
            if cuda_rng_state is not None:
                torch.cuda.set_rng_state(cuda_rng_state)
            self.sd.unet.to(self.device_torch, dtype=dtype)

    def hook_before_train_loop(self):
        """Encode the empty prompt, park the text encoder(s) on CPU, and build the latent cache."""
        # encode our empty prompt
        self.empty_embedding = self.sd.encode_prompt("")
        self.empty_embedding = self.empty_embedding.to(self.device_torch,
                                                       dtype=get_torch_dtype(self.train_config.dtype))

        # Move train model encoder to cpu
        if isinstance(self.sd.text_encoder, list):
            for encoder in self.sd.text_encoder:
                encoder.to('cpu')
                encoder.eval()
                encoder.requires_grad_(False)
        else:
            self.sd.text_encoder.to('cpu')
            self.sd.text_encoder.eval()
            self.sd.text_encoder.requires_grad_(False)

        # self.sd.unet.to('cpu')
        flush()

        self.get_latent_tensors()

        flush()
        # end hook_before_train_loop

    def hook_train_loop(self, batch):
        """Run one optimisation step against a randomly chosen cached latent file.

        ``batch`` is not used; targets come from the latent cache. Returns an
        ordered dict ``{'loss': float}``.
        """
        dtype = get_torch_dtype(self.train_config.dtype)

        loss_function = torch.nn.MSELoss()

        # train it
        # Begin gradient accumulation
        self.sd.unet.train()
        self.sd.unet.requires_grad_(True)
        self.sd.unet.to(self.device_torch, dtype=dtype)

        # target preparation needs no gradients
        with torch.no_grad():
            self.optimizer.zero_grad()

            # pick random latent tensor
            latent_path = random.choice(self.latent_paths)
            latent_tensor = load_file(latent_path)

            noise_pred_target = (latent_tensor['noise_pred_target']).to(self.device_torch, dtype=dtype)
            latents = (latent_tensor['latents']).to(self.device_torch, dtype=dtype)
            guidance_scale = (latent_tensor['guidance_scale']).item()
            timestep = int((latent_tensor['timestep']).item())
            timesteps_to = int((latent_tensor['timesteps_to']).item())
            # seed = int((latent_tensor['seed']).item())

            text_embeddings = train_tools.concat_prompt_embeddings(
                self.empty_embedding,  # unconditional (negative prompt)
                self.empty_embedding,  # conditional (positive prompt)
                self.train_config.batch_size,
            )
            self.sd.noise_scheduler.set_timesteps(
                timesteps_to, device=self.device_torch
            )

            denoised_target = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample

            # get the reduced latents
            # reduced_pred = self.reduce_size_fn(noise_pred_target.detach())
            denoised_target = self.reduce_size_fn(denoised_target.detach())
            reduced_latents = self.reduce_size_fn(latents.detach())

        denoised_target.requires_grad = False
        self.optimizer.zero_grad()
        noise_pred_train = self.sd.predict_noise(
            reduced_latents,
            text_embeddings=text_embeddings,
            timestep=timestep,
            guidance_scale=guidance_scale
        )
        denoised_pred = self.sd.noise_scheduler.step(noise_pred_train, timestep, reduced_latents).prev_sample
        loss = loss_function(denoised_pred, denoised_target)
        loss_float = loss.item()
        loss.backward()
        self.optimizer.step()
        self.lr_scheduler.step()
        self.optimizer.zero_grad()

        flush()

        loss_dict = OrderedDict(
            {'loss': loss_float},
        )

        return loss_dict
        # end hook_train_loop
PromptEmbedsCache, encode_prompts_to_cache, build_prompt_pair_batch_from_cache, split_anchors, \ + split_prompt_pairs + +import torch +from .BaseSDTrainProcess import BaseSDTrainProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +adapter_transforms = transforms.Compose([ + transforms.ToTensor(), +]) + + +class TrainSliderProcess(BaseSDTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.prompt_txt_list = None + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.slider_config = SliderConfig(**self.get_conf('slider', {})) + self.prompt_cache = PromptEmbedsCache() + self.prompt_pairs: list[EncodedPromptPair] = [] + self.anchor_pairs: list[EncodedAnchor] = [] + # keep track of prompt chunk size + self.prompt_chunk_size = 1 + + # check if we have more targets than steps + # this can happen because of permutation son shuffling + if len(self.slider_config.targets) > self.train_config.steps: + # trim targets + self.slider_config.targets = self.slider_config.targets[:self.train_config.steps] + + # get presets + self.eval_slider_device_state = get_train_sd_device_state_preset( + self.device_torch, + train_unet=False, + train_text_encoder=False, + cached_latents=self.is_latents_cached, + train_lora=False, + train_adapter=False, + train_embedding=False, + ) + + self.train_slider_device_state = get_train_sd_device_state_preset( + self.device_torch, + train_unet=self.train_config.train_unet, + train_text_encoder=False, + cached_latents=self.is_latents_cached, + train_lora=True, + train_adapter=False, + train_embedding=False, + ) + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + + # read line by line from file + if self.slider_config.prompt_file: + self.print(f"Loading prompt file from {self.slider_config.prompt_file}") + with 
open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f: + self.prompt_txt_list = f.readlines() + # clean empty lines + self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] + + self.print(f"Found {len(self.prompt_txt_list)} prompts.") + + if not self.slider_config.prompt_tensors: + print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.") + # shuffle + random.shuffle(self.prompt_txt_list) + # trim to max steps + self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps] + # trim list to our max steps + + cache = PromptEmbedsCache() + print(f"Building prompt cache") + + # get encoded latents for our prompts + with torch.no_grad(): + # list of neutrals. Can come from file or be empty + neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""] + + # build the prompts to cache + prompts_to_cache = [] + for neutral in neutral_list: + for target in self.slider_config.targets: + prompt_list = [ + f"{target.target_class}", # target_class + f"{target.target_class} {neutral}", # target_class with neutral + f"{target.positive}", # positive_target + f"{target.positive} {neutral}", # positive_target with neutral + f"{target.negative}", # negative_target + f"{target.negative} {neutral}", # negative_target with neutral + f"{neutral}", # neutral + f"{target.positive} {target.negative}", # both targets + f"{target.negative} {target.positive}", # both targets reverse + ] + prompts_to_cache += prompt_list + + # remove duplicates + prompts_to_cache = list(dict.fromkeys(prompts_to_cache)) + + # trim to max steps if max steps is lower than prompt count + # todo, this can break if we have more targets than steps, should be fixed, by reducing permuations, but could stil happen with low steps + # prompts_to_cache = prompts_to_cache[:self.train_config.steps] + + # encode them + cache = encode_prompts_to_cache( + prompt_list=prompts_to_cache, + sd=self.sd, + 
cache=cache, + prompt_tensor_file=self.slider_config.prompt_tensors + ) + + prompt_pairs = [] + prompt_batches = [] + for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False): + for target in self.slider_config.targets: + prompt_pair_batch = build_prompt_pair_batch_from_cache( + cache=cache, + target=target, + neutral=neutral, + + ) + if self.slider_config.batch_full_slide: + # concat the prompt pairs + # this allows us to run the entire 4 part process in one shot (for slider) + self.prompt_chunk_size = 4 + concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu') + prompt_pairs += [concat_prompt_pair_batch] + else: + self.prompt_chunk_size = 1 + # do them one at a time (probably not necessary after new optimizations) + prompt_pairs += [x.to('cpu') for x in prompt_pair_batch] + + # setup anchors + anchor_pairs = [] + for anchor in self.slider_config.anchors: + # build the cache + for prompt in [ + anchor.prompt, + anchor.neg_prompt # empty neutral + ]: + if cache[prompt] == None: + cache[prompt] = self.sd.encode_prompt(prompt) + + anchor_batch = [] + # we get the prompt pair multiplier from first prompt pair + # since they are all the same. 
We need to match their network polarity + prompt_pair_multipliers = prompt_pairs[0].multiplier_list + for prompt_multiplier in prompt_pair_multipliers: + # match the network multiplier polarity + anchor_scalar = 1.0 if prompt_multiplier > 0 else -1.0 + anchor_batch += [ + EncodedAnchor( + prompt=cache[anchor.prompt], + neg_prompt=cache[anchor.neg_prompt], + multiplier=anchor.multiplier * anchor_scalar + ) + ] + + anchor_pairs += [ + concat_anchors(anchor_batch).to('cpu') + ] + if len(anchor_pairs) > 0: + self.anchor_pairs = anchor_pairs + + # move to cpu to save vram + # We don't need text encoder anymore, but keep it on cpu for sampling + # if text encoder is list + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to("cpu") + else: + self.sd.text_encoder.to("cpu") + self.prompt_cache = cache + self.prompt_pairs = prompt_pairs + # self.anchor_pairs = anchor_pairs + flush() + if self.data_loader is not None: + # we will have images, prep the vae + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + # end hook_before_train_loop + + def before_dataset_load(self): + if self.slider_config.use_adapter == 'depth': + print(f"Loading T2I Adapter for depth") + # called before LoRA network is loaded but after model is loaded + # attach the adapter here so it is there before we load the network + adapter_path = 'TencentARC/t2iadapter_depth_sd15v2' + if self.model_config.is_xl: + adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0' + + print(f"Loading T2I Adapter from {adapter_path}") + + # dont name this adapter since we are not training it + self.t2i_adapter = T2IAdapter.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16" + ).to(self.device_torch) + self.t2i_adapter.eval() + self.t2i_adapter.requires_grad_(False) + flush() + + @torch.no_grad() + def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']): + + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + 
adapter_folder_path = self.slider_config.adapter_img_dir + adapter_images = [] + # loop through images + for file_item in batch.file_items: + img_path = file_item.path + file_name_no_ext = os.path.basename(img_path).split('.')[0] + # find the image + for ext in img_ext_list: + if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)): + adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext)) + break + width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height + adapter_tensors = [] + # load images with torch transforms + for idx, adapter_image in enumerate(adapter_images): + # we need to centrally crop the largest dimension of the image to match the batch shape after scaling + # to the smallest dimension + img: Image.Image = Image.open(adapter_image) + if img.width > img.height: + # scale down so height is the same as batch + new_height = height + new_width = int(img.width * (height / img.height)) + else: + new_width = width + new_height = int(img.height * (width / img.width)) + + img = img.resize((new_width, new_height)) + crop_fn = transforms.CenterCrop((height, width)) + # crop the center to match batch + img = crop_fn(img) + img = adapter_transforms(img) + adapter_tensors.append(img) + + # stack them + adapter_tensors = torch.stack(adapter_tensors).to( + self.device_torch, dtype=get_torch_dtype(self.train_config.dtype) + ) + return adapter_tensors + + def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]): + if isinstance(batch, list): + batch = batch[0] + # set to eval mode + self.sd.set_device_state(self.eval_slider_device_state) + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + + # get a random pair + prompt_pair: EncodedPromptPair = self.prompt_pairs[ + torch.randint(0, len(self.prompt_pairs), (1,)).item() + ] + # move to device and dtype + prompt_pair.to(self.device_torch, dtype=dtype) + + # get a random resolution + height, width = 
self.slider_config.resolutions[ + torch.randint(0, len(self.slider_config.resolutions), (1,)).item() + ] + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() + + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + loss_function = torch.nn.MSELoss() + + pred_kwargs = {} + + def get_noise_pred(neg, pos, gs, cts, dn): + down_kwargs = copy.deepcopy(pred_kwargs) + if 'down_block_additional_residuals' in down_kwargs: + dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0] + if dbr_batch_size != dn.shape[0]: + amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size) + down_kwargs['down_block_additional_residuals'] = [ + torch.cat([sample.clone()] * amount_to_add) for sample in + down_kwargs['down_block_additional_residuals'] + ] + return self.sd.predict_noise( + latents=dn, + text_embeddings=train_tools.concat_prompt_embeddings( + neg, # negative prompt + pos, # positive prompt + self.train_config.batch_size, + ), + timestep=cts, + guidance_scale=gs, + **down_kwargs + ) + + with torch.no_grad(): + adapter_images = None + self.sd.unet.eval() + + # for a complete slider, the batch size is 4 to begin with now + true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size + from_batch = False + if batch is not None: + # traing from a batch of images, not generating ourselves + from_batch = True + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + if self.slider_config.adapter_img_dir is not None: + adapter_images = self.get_adapter_images(batch) + adapter_strength_min = 0.9 + adapter_strength_max = 1.0 + + def rand_strength(sample): + adapter_conditioning_scale = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) + + adapter_conditioning_scale = value_map( + adapter_conditioning_scale, + 0.0, + 1.0, + adapter_strength_min, + 
adapter_strength_max + ) + return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale + + down_block_additional_residuals = self.t2i_adapter(adapter_images) + down_block_additional_residuals = [ + rand_strength(sample) for sample in down_block_additional_residuals + ] + pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals + + # denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0) + denoised_latents = noisy_latents + current_timestep = timesteps + else: + if self.train_config.noise_scheduler == 'flowmatch': + linear_timesteps = any([ + self.train_config.linear_timesteps, + self.train_config.linear_timesteps2, + self.train_config.timestep_type == 'linear', + ]) + + timestep_type = 'linear' if linear_timesteps else None + if timestep_type is None: + timestep_type = self.train_config.timestep_type + + # make fake latents + l = torch.randn( + true_batch_size, 16, height, width + ).to(self.device_torch, dtype=dtype) + + self.sd.noise_scheduler.set_train_timesteps( + self.train_config.max_denoising_steps, + device=self.device_torch, + timestep_type=timestep_type, + latents=l + ) + else: + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + # ger a random number of steps + timesteps_to = torch.randint( + 1, self.train_config.max_denoising_steps - 1, (1,) + ).item() + + # get noise + noise = self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=true_batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + # get latents + latents = noise * self.sd.noise_scheduler.init_noise_sigma + latents = latents.to(self.device_torch, dtype=dtype) + + assert not self.network.is_active + self.sd.unet.eval() + # pass the multiplier list to the network + # double up since we are doing cfg + self.network.multiplier = prompt_pair.multiplier_list + prompt_pair.multiplier_list + 
denoised_latents = self.sd.diffuse_some_steps( + latents, # pass simple noise latents + prompt_pair.target_class, + start_timesteps=0, + total_timesteps=timesteps_to, + guidance_scale=3, + bypass_guidance_embedding=False + ) + if hasattr(self.sd.noise_scheduler, 'set_train_timesteps'): + noise_scheduler.set_train_timesteps(1000, device=self.device_torch) + else: + noise_scheduler.set_timesteps(1000) + + current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps) + current_timestep = noise_scheduler.timesteps[current_timestep_index] + + # split the latents into out prompt pair chunks + # denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0) + # denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks] + denoised_latent_chunks = [denoised_latents] + + # flush() # 4.2GB to 3GB on 512x512 + mask_multiplier = torch.ones((denoised_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + has_mask = False + if batch and batch.mask_tensor is not None: + with self.timer('get_mask_multiplier'): + # upsampling no supported for bfloat16 + mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach() + # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height) + mask_multiplier = torch.nn.functional.interpolate( + mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3]) + ) + # expand to match latents + mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) + mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + has_mask = True + + if has_mask: + unmasked_target = get_noise_pred( + prompt_pair.positive_target, # negative prompt + prompt_pair.target_class, # positive prompt + 1, + current_timestep, + denoised_latents + ) + unmasked_target = unmasked_target.detach() + unmasked_target.requires_grad = False + else: + unmasked_target = 
None + + # 4.20 GB RAM for 512x512 + # positive_latents = get_noise_pred( + # prompt_pair.positive_target, # negative prompt + # prompt_pair.negative_target, # positive prompt + # 1, + # current_timestep, + # denoised_latents + # ) + # positive_latents = positive_latents.detach() + # positive_latents.requires_grad = False + + # neutral_latents = get_noise_pred( + # prompt_pair.positive_target, # negative prompt + # prompt_pair.empty_prompt, # positive prompt (normally neutral + # 1, + # current_timestep, + # denoised_latents + # ) + # neutral_latents = neutral_latents.detach() + # neutral_latents.requires_grad = False + + # unconditional_latents = get_noise_pred( + # prompt_pair.positive_target, # negative prompt + # prompt_pair.positive_target, # positive prompt + # 1, + # current_timestep, + # denoised_latents + # ) + # unconditional_latents = unconditional_latents.detach() + # unconditional_latents.requires_grad = False + + # we just need positive target, negative target, and empty prompt to calculate all + # since we are in no grad, we can easily do it in a single step + embeddings = train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, + prompt_pair.empty_prompt, + 1 + ) + embeddings = train_tools.concat_prompt_embeddings( + embeddings, + prompt_pair.negative_target, + 1 + ) + all_pred = self.sd.predict_noise( + latents=torch.cat([denoised_latents] * 3, dim=0), + text_embeddings=embeddings, + timestep=torch.cat([current_timestep] * 3, dim=0), + ) + all_pred = all_pred.detach() + all_pred.requires_grad = False + positive_pred, neutral_pred, unconditional_pred = torch.chunk(all_pred, 3, dim=0) + + # doing them backward here as it was originally for erasing + positive_latents = unconditional_pred + neutral_latents = neutral_pred + unconditional_latents = positive_pred + + + denoised_latents = denoised_latents.detach() + + self.sd.set_device_state(self.train_slider_device_state) + self.sd.unet.train() + # start accumulating gradients + 
self.optimizer.zero_grad(set_to_none=True) + + anchor_loss_float = None + + with torch.no_grad(): + if self.slider_config.low_ram: + prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size) + denoised_latent_chunks = denoised_latent_chunks # just to have it in one place + positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0) + neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0) + unconditional_latents_chunks = torch.chunk( + unconditional_latents.detach(), + self.prompt_chunk_size, + dim=0 + ) + mask_multiplier_chunks = torch.chunk(mask_multiplier, self.prompt_chunk_size, dim=0) + if unmasked_target is not None: + unmasked_target_chunks = torch.chunk(unmasked_target, self.prompt_chunk_size, dim=0) + else: + unmasked_target_chunks = [None for _ in range(self.prompt_chunk_size)] + else: + # run through in one instance + prompt_pair_chunks = [prompt_pair.detach()] + denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()] + positive_latents_chunks = [positive_latents.detach()] + neutral_latents_chunks = [neutral_latents.detach()] + unconditional_latents_chunks = [unconditional_latents.detach()] + mask_multiplier_chunks = [mask_multiplier] + unmasked_target_chunks = [unmasked_target] + + # flush() + assert len(prompt_pair_chunks) == len(denoised_latent_chunks) + # 3.28 GB RAM for 512x512 + with self.network: + assert self.network.is_active + loss_list = [] + for prompt_pair_chunk, \ + denoised_latent_chunk, \ + positive_latents_chunk, \ + neutral_latents_chunk, \ + unconditional_latents_chunk, \ + mask_multiplier_chunk, \ + unmasked_target_chunk \ + in zip( + prompt_pair_chunks, + denoised_latent_chunks, + positive_latents_chunks, + neutral_latents_chunks, + unconditional_latents_chunks, + mask_multiplier_chunks, + unmasked_target_chunks + ): + self.network.multiplier = prompt_pair_chunk.multiplier_list + + target_latents = 
self.sd.predict_noise( + latents=denoised_latent_chunk.detach(), + text_embeddings=prompt_pair_chunk.target_class, + timestep=current_timestep, + ) + + guidance_scale = 1.0 + + offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk) + + # make offset multiplier based on actions + offset_multiplier_list = [] + for action in prompt_pair_chunk.action_list: + if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE: + offset_multiplier_list += [-1.0] + elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE: + offset_multiplier_list += [1.0] + + offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype) + # make offset multiplier match rank of offset + offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 1) + offset *= offset_multiplier + + offset_neutral = neutral_latents_chunk + # offsets are already adjusted on a per-batch basis + offset_neutral += offset + offset_neutral = offset_neutral.detach().requires_grad_(False) + + # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing + loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none") + + # do inverted mask to preserve non masked + if has_mask and unmasked_target_chunk is not None: + loss = loss * mask_multiplier_chunk + # match the mask unmasked_target_chunk + mask_target_loss = torch.nn.functional.mse_loss( + target_latents.float(), + unmasked_target_chunk.float(), + reduction="none" + ) + mask_target_loss = mask_target_loss * (1.0 - mask_multiplier_chunk) + loss += mask_target_loss + + loss = loss.mean([1, 2, 3]) + + if self.train_config.learnable_snr_gos: + if from_batch: + # match batch size + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, + self.train_config.min_snr_gamma) + else: + # match batch size + timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])] + # add snr_gamma + loss = apply_learnable_snr_gos(loss, 
timesteps_index_list, self.snr_gos) + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + if from_batch: + # match batch size + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, + self.train_config.min_snr_gamma) + else: + # match batch size + timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])] + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, + self.train_config.min_snr_gamma) + + + loss = loss.mean() * prompt_pair_chunk.weight + + loss.backward() + loss_list.append(loss.item()) + del target_latents + del offset_neutral + del loss + # flush() + + optimizer.step() + lr_scheduler.step() + + loss_float = sum(loss_list) / len(loss_list) + if anchor_loss_float is not None: + loss_float += anchor_loss_float + + del ( + positive_latents, + neutral_latents, + unconditional_latents, + # latents + ) + # move back to cpu + prompt_pair.to("cpu") + # flush() + + # reset network + self.network.multiplier = 1.0 + + loss_dict = OrderedDict( + {'loss': loss_float}, + ) + if anchor_loss_float is not None: + loss_dict['sl_l'] = loss_float + loss_dict['an_l'] = anchor_loss_float + + return loss_dict + # end hook_train_loop diff --git a/ai-toolkit/jobs/process/TrainSliderProcessOld.py b/ai-toolkit/jobs/process/TrainSliderProcessOld.py new file mode 100644 index 0000000000000000000000000000000000000000..851bbf4a49c2f0618c8f326ed57202d9e5fc316e --- /dev/null +++ b/ai-toolkit/jobs/process/TrainSliderProcessOld.py @@ -0,0 +1,404 @@ +# ref: +# - https://github.com/p1atdev/LECO/blob/main/train_lora.py +import time +from collections import OrderedDict +import os +from typing import Optional + +from toolkit.config_modules import SliderConfig +import sys + +from toolkit.stable_diffusion_model import PromptEmbeds + +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +import gc +from toolkit import train_tools + +import torch +from 
.BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion + + +class ACTION_TYPES_SLIDER: + ERASE_NEGATIVE = 0 + ENHANCE_NEGATIVE = 1 + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class EncodedPromptPair: + def __init__( + self, + target_class, + positive, + negative, + neutral, + width=512, + height=512, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=1.0, + weight=1.0 + ): + self.target_class = target_class + self.positive = positive + self.negative = negative + self.neutral = neutral + self.width = width + self.height = height + self.action: int = action + self.multiplier = multiplier + self.weight = weight + + +class PromptEmbedsCache: # 使いまわしたいので + prompts: dict[str, PromptEmbeds] = {} + + def __setitem__(self, __name: str, __value: PromptEmbeds) -> None: + self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PromptEmbeds]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + +class EncodedAnchor: + def __init__( + self, + prompt, + neg_prompt, + multiplier=1.0 + ): + self.prompt = prompt + self.neg_prompt = neg_prompt + self.multiplier = multiplier + + +class TrainSliderProcessOld(BaseSDTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.slider_config = SliderConfig(**self.get_conf('slider', {})) + self.prompt_cache = PromptEmbedsCache() + self.prompt_pairs: list[EncodedPromptPair] = [] + self.anchor_pairs: list[EncodedAnchor] = [] + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + cache = PromptEmbedsCache() + prompt_pairs: list[EncodedPromptPair] = [] + + # get encoded latents for our prompts + with torch.no_grad(): + neutral = "" + for target in self.slider_config.targets: + # build the cache + for prompt in [ + 
target.target_class, + target.positive, + target.negative, + neutral # empty neutral + ]: + if cache[prompt] is None: + cache[prompt] = self.sd.encode_prompt(prompt) + for resolution in self.slider_config.resolutions: + width, height = resolution + only_erase = len(target.positive.strip()) == 0 + only_enhance = len(target.negative.strip()) == 0 + + both = not only_erase and not only_enhance + + if only_erase and only_enhance: + raise ValueError("target must have at least one of positive or negative or both") + # for slider we need to have an enhancer, an eraser, and then + # an inverse with negative weights to balance the network + # if we don't do this, we will get different contrast and focus. + # we only perform actions of enhancing and erasing on the negative + # todo work on way to do all of this in one shot + + if both or only_erase: + prompt_pairs += [ + # erase standard + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.positive], + negative=cache[target.negative], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier, + weight=target.weight + ), + ] + if both or only_enhance: + prompt_pairs += [ + # enhance standard, swap pos neg + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.negative], + negative=cache[target.positive], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier, + weight=target.weight + ), + ] + if both: + prompt_pairs += [ + # erase inverted + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.negative], + negative=cache[target.positive], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] + prompt_pairs += [ + # enhance inverted + EncodedPromptPair( + 
target_class=cache[target.target_class], + positive=cache[target.positive], + negative=cache[target.negative], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] + + # setup anchors + anchor_pairs = [] + for anchor in self.slider_config.anchors: + # build the cache + for prompt in [ + anchor.prompt, + anchor.neg_prompt # empty neutral + ]: + if cache[prompt] == None: + cache[prompt] = self.sd.encode_prompt(prompt) + + anchor_pairs += [ + EncodedAnchor( + prompt=cache[anchor.prompt], + neg_prompt=cache[anchor.neg_prompt], + multiplier=anchor.multiplier + ) + ] + + # move to cpu to save vram + # We don't need text encoder anymore, but keep it on cpu for sampling + # if text encoder is list + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to("cpu") + else: + self.sd.text_encoder.to("cpu") + self.prompt_cache = cache + self.prompt_pairs = prompt_pairs + self.anchor_pairs = anchor_pairs + flush() + # end hook_before_train_loop + + def hook_train_loop(self, batch): + dtype = get_torch_dtype(self.train_config.dtype) + + # get a random pair + prompt_pair: EncodedPromptPair = self.prompt_pairs[ + torch.randint(0, len(self.prompt_pairs), (1,)).item() + ] + + height = prompt_pair.height + width = prompt_pair.width + target_class = prompt_pair.target_class + neutral = prompt_pair.neutral + negative = prompt_pair.negative + positive = prompt_pair.positive + weight = prompt_pair.weight + multiplier = prompt_pair.multiplier + + unet = self.sd.unet + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + loss_function = torch.nn.MSELoss() + + def get_noise_pred(p, n, gs, cts, dn): + return self.sd.predict_noise( + latents=dn, + text_embeddings=train_tools.concat_prompt_embeddings( + p, # unconditional + n, # positive + self.train_config.batch_size, + ), + 
timestep=cts, + guidance_scale=gs, + ) + + # set network multiplier + self.network.multiplier = multiplier + + with torch.no_grad(): + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + self.optimizer.zero_grad() + + # ger a random number of steps + timesteps_to = torch.randint( + 1, self.train_config.max_denoising_steps, (1,) + ).item() + + # get noise + noise = self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=self.train_config.batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + # get latents + latents = noise * self.sd.noise_scheduler.init_noise_sigma + latents = latents.to(self.device_torch, dtype=dtype) + + with self.network: + assert self.network.is_active + self.network.multiplier = multiplier + denoised_latents = self.sd.diffuse_some_steps( + latents, # pass simple noise latents + train_tools.concat_prompt_embeddings( + positive, # unconditional + target_class, # target + self.train_config.batch_size, + ), + start_timesteps=0, + total_timesteps=timesteps_to, + guidance_scale=3, + ) + + noise_scheduler.set_timesteps(1000) + + current_timestep = noise_scheduler.timesteps[ + int(timesteps_to * 1000 / self.train_config.max_denoising_steps) + ] + + positive_latents = get_noise_pred( + positive, negative, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + neutral_latents = get_noise_pred( + positive, neutral, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + unconditional_latents = get_noise_pred( + positive, positive, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + anchor_loss = None + if len(self.anchor_pairs) > 0: + # get a random anchor pair + anchor: EncodedAnchor = self.anchor_pairs[ + torch.randint(0, len(self.anchor_pairs), (1,)).item() + ] + with torch.no_grad(): + anchor_target_noise = get_noise_pred( + anchor.prompt, 
anchor.neg_prompt, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + with self.network: + # anchor whatever weight prompt pair is using + pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0 + self.network.multiplier = anchor.multiplier * pos_nem_mult + + anchor_pred_noise = get_noise_pred( + anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + self.network.multiplier = prompt_pair.multiplier + + with self.network: + self.network.multiplier = prompt_pair.multiplier + target_latents = get_noise_pred( + positive, target_class, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + # if self.logging_config.verbose: + # self.print("target_latents:", target_latents[0, 0, :5, :5]) + + positive_latents.requires_grad = False + neutral_latents.requires_grad = False + unconditional_latents.requires_grad = False + if len(self.anchor_pairs) > 0: + anchor_target_noise.requires_grad = False + anchor_loss = loss_function( + anchor_target_noise, + anchor_pred_noise, + ) + erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE + guidance_scale = 1.0 + + offset = guidance_scale * (positive_latents - unconditional_latents) + + offset_neutral = neutral_latents + if erase: + offset_neutral -= offset + else: + # enhance + offset_neutral += offset + + loss = loss_function( + target_latents, + offset_neutral, + ) * weight + + loss_slide = loss.item() + + if anchor_loss is not None: + loss += anchor_loss + + loss_float = loss.item() + + loss = loss.to(self.device_torch) + + loss.backward() + optimizer.step() + lr_scheduler.step() + + del ( + positive_latents, + neutral_latents, + unconditional_latents, + target_latents, + latents, + ) + flush() + + # reset network + self.network.multiplier = 1.0 + + loss_dict = OrderedDict( + {'loss': loss_float}, + ) + if anchor_loss is not None: + loss_dict['sl_l'] = loss_slide + loss_dict['an_l'] = anchor_loss.item() + + return 
def unnormalize(tensor):
    """Map a tensor from the [-1, 1] range back to [0, 1].

    Undoes a ``Normalize([0.5], [0.5])`` style transform; anything that falls
    outside the target range after the shift is clamped.
    """
    return (tensor / 2 + 0.5).clamp(0, 1)


def channel_dropout(x, p=0.5):
    """Randomly zero whole channels of ``x`` and rescale the survivors.

    One keep/drop decision is drawn per (batch item, channel) pair and is
    broadcast over the trailing two dimensions, so a dropped channel is zero
    everywhere. Kept channels are divided by the keep probability so the
    expected value of the output matches the input.

    Args:
        x: input tensor; indexed as (batch, channel, ...) and broadcast
            against a (batch, channel, 1, 1) mask, i.e. expected to be 4-D.
        p: probability of dropping a channel.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    keep_prob = 1 - p
    if keep_prob >= 1:
        # nothing can be dropped; the mask would be all ones
        return x
    if keep_prob <= 0:
        # everything is dropped; dividing by keep_prob here would yield NaN (0 / 0)
        return torch.zeros_like(x)

    keep = torch.rand(x.size(0), x.size(1), 1, 1, device=x.device, dtype=x.dtype) < keep_prob
    # cast the boolean mask to x's dtype before scaling: bool / float yields the
    # default float dtype, which would silently promote fp16 / bf16 inputs
    mask = keep.to(x.dtype) / keep_prob
    return x * mask


def sharpen_image(images: torch.Tensor) -> torch.Tensor:
    """Sharpen a batch of images with a fixed 3x3 kernel.

    The kernel is applied to every channel independently (depthwise
    convolution) with one pixel of zero padding, so the spatial size is
    preserved.

    Args:
        images: tensor of shape (batch, channels, height, width).

    Returns:
        Sharpened tensor with the same shape, dtype and device as ``images``.
    """
    channels = images.shape[1]

    # sharpening kernel: centre weight 5, 4-neighbourhood weight -1
    kernel = torch.tensor([
        [0, -1, 0],
        [-1, 5, -1],
        [0, -1, 0]
    ], dtype=images.dtype, device=images.device).view(1, 1, 3, 3)

    # one copy of the kernel per channel: (out_channels, in_channels / groups, kH, kW)
    kernel = kernel.repeat(channels, 1, 1, 1)

    # groups == channels keeps the channels independent of each other
    return F.conv2d(images, kernel, padding=1, groups=channels)
self.get_conf('content_weight', 0, as_type=float) + self.kld_weight = self.get_conf('kld_weight', 0, as_type=float) + self.clip_weight = self.get_conf('clip_weight', 0, as_type=float) + self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float) + self.mae_weight = self.get_conf('mae_weight', 0, as_type=float) + self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float) + self.tv_weight = self.get_conf('tv_weight', 0, as_type=float) + self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float) + self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float) # latent pixel matching + self.lpips_weight = self.get_conf('lpips_weight', 0, as_type=float) + self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) + self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float) + self.optimizer_params = self.get_conf('optimizer_params', {}) + self.vae_config = self.get_conf('vae_config', None) + self.dropout = self.get_conf('dropout', 0.0, as_type=float) + self.train_encoder = self.get_conf('train_encoder', False, as_type=bool) + self.random_scaling = self.get_conf('random_scaling', False, as_type=bool) + self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str) # AutoencoderKL or AutoencoderTiny + self.only_if_contains = self.get_conf('only_if_contains', None) + + self.do_pooled_exits = False + self.VaeClass = AutoencoderKL + if self.vae_type == 'AutoencoderTiny': + self.VaeClass = AutoencoderTiny + if self.vae_type == 'AutoencoderTinyWithPooledExits': + self.VaeClass = AutoencoderTinyWithPooledExits + self.do_pooled_exits = True + + if not self.train_encoder: + # remove losses that only target encoder + self.kld_weight = 0 + self.mv_loss_weight = 0 + self.ltv_weight = 0 + self.lpm_weight = 0 + + self.blocks_to_train = self.get_conf('blocks_to_train', ['all']) + self.torch_dtype = get_torch_dtype(self.dtype) + self.vgg_19 = None + self.clip = None + self.clip_image_processor = None + self.clip_image_size = 256 + 
def update_training_metadata(self):
    """Store the current step / epoch counters in the job metadata."""
    training_info = self.get_training_info()
    self.add_meta(OrderedDict({"training_info": training_info}))


def get_training_info(self):
    """Return the current progress counters as an ordered mapping."""
    info = OrderedDict()
    info['step'] = self.step_num
    info['epoch'] = self.epoch_num
    return info


def load_datasets(self):
    """Build the shuffled training ``DataLoader`` once; later calls are no-ops."""
    if self.data_loader is not None:
        return

    print("Loading datasets")

    # with random scaling enabled, images are loaded at 2x the training resolution
    dataset_res = self.resolution
    if self.random_scaling:
        dataset_res = int(dataset_res * 2)

    datasets = []
    for dataset in self.datasets_objects:
        print(f" - Dataset: {dataset['path']}")
        # shallow copy so the resolution override does not touch the job config
        ds = copy.copy(dataset)
        ds['resolution'] = dataset_res
        datasets.append(ImageDataset(ds))

    self.data_loader = DataLoader(
        ConcatDataset(datasets),
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=16
    )


def remove_oldest_checkpoint(self):
    """Prune saved checkpoints so only the newest four of each kind remain.

    Handles both the ``<job name>*_diffusers`` folders and the
    ``CRITIC_<job name>*.safetensors`` critic files; age is judged by
    modification time.
    """
    max_to_keep = 4

    def stale(pattern):
        # everything except the `max_to_keep` most recently modified matches
        matches = glob.glob(os.path.join(self.save_root, pattern))
        if len(matches) <= max_to_keep:
            return []
        return sorted(matches, key=os.path.getmtime)[:-max_to_keep]

    for folder in stale(f"{self.job.name}*_diffusers"):
        print(f"Removing {folder}")
        shutil.rmtree(folder)

    # critic checkpoints look like CRITIC_vae_42_000000500.safetensors
    for file in stale(f"CRITIC_{self.job.name}*.safetensors"):
        print(f"Removing {file}")
        os.remove(file)


def setup_vgg19(self):
    """Create the VGG19 style/content model and per-layer loss scalers (once)."""
    if self.vgg_19 is not None:
        return

    self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
        single_target=True,
        device=self.device,
        output_layer_name='pool_4',
        dtype=self.torch_dtype
    )
    self.vgg_19.to(self.device, dtype=self.torch_dtype)
    self.vgg_19.requires_grad_(False)

    # Push random noise through the network first so every loss layer holds a
    # value; its reciprocal is used to normalise that layer's loss to 1.
    # Batch size is 2 because prediction and target are run through stacked.
    noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype)
    self.vgg_19(noise)

    self.style_weight_scalers.extend(
        1 / torch.mean(layer.loss).item() for layer in self.style_losses
    )
    self.content_weight_scalers.extend(
        1 / torch.mean(layer.loss).item() for layer in self.content_losses
    )

    self.print(f"Style weight scalers: {self.style_weight_scalers}")
    self.print(f"Content weight scalers: {self.content_weight_scalers}")
def get_clip_embeddings(self, image_n1p1):
    """Return the vision encoder's last hidden state for a batch of images.

    ``image_n1p1`` is expected in the [-1, 1] range. Images are mapped to
    [0, 1], sharpened, clamped, resized to the encoder's input size when
    needed, then normalised with mean 0.5 / std 0.5 before being fed to
    ``self.clip`` in bfloat16.
    """
    size = (self.clip_image_size, self.clip_image_size)

    images = sharpen_image((image_n1p1 + 1) / 2).clamp(0, 1)
    if images.shape[-2:] != size:
        images = torch.nn.functional.interpolate(images, size=size, mode='bilinear', align_corners=False)

    # mean and std are both 0.5 per channel, so one tensor serves for both
    half = torch.tensor([0.5, 0.5, 0.5]).to(
        images.device, dtype=images.dtype
    ).view([1, 3, 1, 1]).detach()
    clip_image = (images - half) / half

    id_embeds = self.clip(
        clip_image.to(self.clip.device, dtype=torch.bfloat16),
        output_hidden_states=True,
    )
    return id_embeds['last_hidden_state']


def get_clip_loss(self, pred, target):
    """MSE between the encoder embeddings of ``pred`` and ``target`` (both in [-1, 1]).

    Only the prediction branch is tracked by autograd.
    """
    with torch.no_grad():
        target_embeddings = self.get_clip_embeddings(target).float()
    pred_embeddings = self.get_clip_embeddings(pred).float()
    return torch.nn.functional.mse_loss(pred_embeddings, target_embeddings)


def get_pooled_output_loss(self, pooled_outputs, target):
    """Average MSE between each pooled exit and the target resized to match it.

    ``pooled_outputs`` is a list of (batch, 3, h, w) tensors and ``target`` a
    (batch, 3, H, W) tensor. ``None`` or an empty list yields a zero tensor.
    """
    if pooled_outputs is None or len(pooled_outputs) == 0:
        return torch.tensor(0.0, device=self.device)

    total = 0.0
    for pooled in pooled_outputs:
        with torch.no_grad():
            # bring the target down (or up) to this exit's spatial size
            resized = torch.nn.functional.interpolate(
                target, size=pooled.shape[2:], mode='bilinear', align_corners=False
            )
        total = total + torch.nn.functional.mse_loss(pooled.float(), resized.float())
    return total / len(pooled_outputs)


def get_style_loss(self):
    """Sum of the per-layer style losses, each scaled by its normalising scaler."""
    if not self.style_weight > 0:
        return torch.tensor(0.0, device=self.device)
    scaled = [
        layer.loss * scaler
        for layer, scaler in zip(self.style_losses, self.style_weight_scalers)
    ]
    return torch.sum(torch.stack(scaled))


def get_content_loss(self):
    """Sum of the per-layer content losses, each scaled by its normalising scaler."""
    if not self.content_weight > 0:
        return torch.tensor(0.0, device=self.device)
    scaled = [
        layer.loss * scaler
        for layer, scaler in zip(self.content_losses, self.content_weight_scalers)
    ]
    return torch.sum(torch.stack(scaled))


def get_mse_loss(self, pred, target):
    """Mean of the plain MSE and the MSE between sharpened pred / target."""
    if not self.mse_weight > 0:
        return torch.tensor(0.0, device=self.device)

    criterion = nn.MSELoss()
    plain = criterion(pred, target)

    sharp_pred = sharpen_image(pred)
    with torch.no_grad():
        sharp_target = sharpen_image(target)
    sharp = criterion(sharp_pred, sharp_target)

    return (sharp + plain) / 2
def get_kld_loss(self, mu, log_var):
    """KL divergence of N(mu, exp(log_var)) from a standard normal, summed over all elements.

    Returns a zero tensor when the KLD weight is not positive, or when no
    distribution parameters are available. The training loop passes
    ``mu = logvar = None`` for VAE types that do not expose a latent
    distribution, which previously crashed on ``mu.pow``.
    """
    if not self.kld_weight > 0 or mu is None or log_var is None:
        return torch.tensor(0.0, device=self.device)
    # keeps the latent distribution close to a unit Gaussian
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())


def get_mean_variance_loss(self, latents: torch.Tensor):
    """Penalise latents whose mean drifts from 0 or whose std drifts from 1.

    The penalty is evaluated three ways and the six terms are summed and
    divided by 3: per column (every width position folded into the channel
    axis), per row (every height position folded into the channel axis),
    and per channel over the whole spatial extent.

    Args:
        latents: tensor of shape (batch, channels, height, width).
    """
    if not self.mv_loss_weight > 0:
        return torch.tensor(0.0, device=self.device)

    def moment_penalties(t):
        # squared mean and squared (std - 1), each averaged over batch and channels
        mean = t.mean(dim=(2, 3), keepdim=True)
        std = t.std(dim=(2, 3), keepdim=True, unbiased=False)
        return (mean ** 2).mean(), ((std - 1) ** 2).mean()

    # gw == width, so each column becomes its own channel of shape (h, 1)
    latents_col = rearrange(latents, 'b c h (gw w) -> b (c gw) h w', gw=latents.shape[-1])
    mean_loss_col, std_loss_col = moment_penalties(latents_col)

    # gh == height, so each row becomes its own channel of shape (1, w)
    latents_row = rearrange(latents, 'b c (gh h) w -> b (c gh) h w', gh=latents.shape[-2])
    mean_loss_row, std_loss_row = moment_penalties(latents_row)

    # global statistics per channel
    mean_loss_global, std_loss_global = moment_penalties(latents)

    return (mean_loss_col + std_loss_col + mean_loss_row + std_loss_row + mean_loss_global + std_loss_global) / 3


def get_ltv_loss(self, latent, images):
    """Match the latent's total-variation deltas to those of the source image.

    The image is resized to the latent's spatial size, reduced to a single
    luminance-like channel (mean over colour), repeated across the latent's
    channels and standardised per channel. The loss is the mean absolute
    difference between the absolute total-variation deltas of the latent and
    of that reference, so the latent is pushed toward the image's local
    variation rather than toward zero variation.
    """
    if not self.ltv_weight > 0:
        return torch.tensor(0.0, device=self.device)

    with torch.no_grad():
        images = images.to(latent.device, dtype=latent.dtype)
        # resize down to latent size
        images = torch.nn.functional.interpolate(
            images, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False
        )

        # mean the color channel and then expand to the latent's channel count
        images = images.mean(dim=1, keepdim=True)
        images = images.repeat(1, latent.shape[1], 1, 1)

        # normalize to a mean of 0 and std of 1
        images_mean = images.mean(dim=(2, 3), keepdim=True)
        images_std = images.std(dim=(2, 3), keepdim=True)
        images = (images - images_mean) / (images_std + 1e-6)

    latent_tv = torch.abs(total_variation_deltas(latent))
    images_tv = torch.abs(total_variation_deltas(images))
    loss = torch.abs(latent_tv - images_tv)  # keep it spatially aware
    loss = loss.mean(dim=2, keepdim=True)
    loss = loss.mean(dim=3, keepdim=True)  # mean over height and width
    loss = loss.mean(dim=1, keepdim=True)  # mean over channels
    return loss.mean()
def get_tv_loss(self, pred, target):
    """Comparative total-variation loss between ``pred`` and ``target`` (zero when disabled)."""
    if not self.tv_weight > 0:
        return torch.tensor(0.0, device=self.device)
    return ComparativeTotalVariation()(pred, target)


def get_pattern_loss(self, pred, target):
    """Mean pattern loss between ``pred`` and ``target`` (zero when disabled).

    The ``PatternLoss`` module is created lazily on first use and cached on
    the instance. Like the other loss getters, nothing is built or evaluated
    when the corresponding weight is not positive.
    """
    if not self.pattern_weight > 0:
        return torch.tensor(0.0, device=self.device)
    if self._pattern_loss is None:
        self._pattern_loss = PatternLoss(pattern_size=16, dtype=self.torch_dtype).to(
            self.device, dtype=self.torch_dtype
        )
    return torch.mean(self._pattern_loss(pred, target))


def save(self, step=None):
    """Write the VAE (and critic, if enabled) to ``save_root`` and prune old checkpoints.

    Args:
        step: optional step number; when given it is zero-padded to nine
            digits and embedded in the checkpoint folder name.
    """
    os.makedirs(self.save_root, exist_ok=True)

    step_num = ''
    if step is not None:
        # zeropad 9 digits
        step_num = f"_{str(step).zfill(9)}"

    self.update_training_metadata()
    filename = f'{self.job.name}{step_num}_diffusers'
    save_path = os.path.join(self.save_root, filename)

    # NOTE(review): the model is cast to float16 for saving and then cast back;
    # presumably this rounds the live training weights through float16 on
    # every save when training in a wider dtype - confirm this is intended.
    self.vae = self.vae.to("cpu", dtype=torch.float16)
    try:
        self.vae.save_pretrained(save_directory=save_path)
    finally:
        # always move the model back, even if writing the checkpoint failed,
        # so training does not continue on the CPU in half precision
        self.vae = self.vae.to(self.device, dtype=self.torch_dtype)

    self.print(f"Saved to {save_path}")

    if self.use_critic:
        self.critic.save(step)

    self.remove_oldest_checkpoint()
self.target_latent_vae is not None: + target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor + target_input_size = (int(img.shape[2] * target_input_scale), int(img.shape[3] * target_input_scale)) + # resize to target input size + target_input_batch = Resize(target_input_size)(img).to(self.device, dtype=torch.float32) + target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach() + shift = self.target_latent_vae.config['shift_factor'] if self.target_latent_vae.config['shift_factor'] is not None else 0 + target_latent = self.target_latent_vae.config['scaling_factor'] * (target_latent - shift) + target_latent = target_latent.to(self.device, dtype=self.torch_dtype) + latent = self.vae.encode(img, return_dict=False)[0] + + if hasattr(latent, 'sample'): + latent = latent.sample() + + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + latent = self.vae.config['scaling_factor'] * (latent - shift) + + latent_img = latent.clone() + bs, ch, h, w = latent_img.shape + grid_size = math.ceil(math.sqrt(ch)) + pad = grid_size * grid_size - ch + + # take first item in batch + latent_img = latent_img[0] # shape: (ch, h, w) + + if pad > 0: + padding = torch.zeros((pad, h, w), dtype=latent_img.dtype, device=latent_img.device) + latent_img = torch.cat([latent_img, padding], dim=0) + + # make grid + new_img = torch.zeros((1, grid_size * h, grid_size * w), dtype=latent_img.dtype, device=latent_img.device) + for x in range(grid_size): + for y in range(grid_size): + if x * grid_size + y < ch: + new_img[0, x * h:(x + 1) * h, y * w:(y + 1) * w] = latent_img[x * grid_size + y] + latent_img = new_img + # make rgb + latent_img = latent_img.repeat(3, 1, 1).unsqueeze(0) + latent_img = (latent_img / 2 + 0.5).clamp(0, 1) + + # resize to 256x256 + latent_img = torch.nn.functional.interpolate(latent_img, size=(self.sample_resolution, self.sample_resolution), mode='nearest') + latent_img = 
def load_vae(self):
    """Load (or resume) the VAE to train and the optional target-latent VAE.

    The newest ``<job name>*_diffusers`` folder in ``save_root`` takes
    precedence over ``vae_path``. If neither exists, a fresh model is built
    from ``vae_config``. The VAE is moved to the training device / dtype and
    ``vae_scale_factor`` is derived from its ``block_out_channels``.

    Raises:
        ValueError: if there is no checkpoint, no ``vae_path`` and no
            ``vae_config`` to create the model from.
    """
    path_to_load = self.vae_path

    # see if we have a checkpoint in our output folder to resume from
    self.print(f"Looking for latest checkpoint in {self.save_root}")
    files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
    if files:
        latest_file = max(files, key=os.path.getmtime)
        print(f" - Latest checkpoint is: {latest_file}")
        path_to_load = latest_file
        # todo update step and epoch count
    else:
        self.print(f" - No checkpoint found, starting from scratch")

    # load vae
    self.print(f"Loading VAE")
    self.print(f" - Loading VAE: {path_to_load}")
    if self.vae is None:
        if path_to_load is not None:
            self.vae = self.VaeClass.from_pretrained(path_to_load)
        elif self.vae_config is not None:
            self.vae = self.VaeClass(**self.vae_config)
        else:
            # the config key is `vae_config`; the message used to say `ae_config`
            raise ValueError('vae_path or vae_config must be specified')

    self.vae.to(self.device, dtype=self.torch_dtype)
    if self.eq_vae:
        self.vae.encoder.train()
    else:
        # freeze everything and put the model in eval mode
        self.vae.requires_grad_(False)
        self.vae.eval()
    # set decoder to train
    self.vae.decoder.train()
    # spatial reduction factor: one factor of 2 per block after the first
    self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)

    if self.target_latent_vae_path is not None:
        self.print(f"Loading target latent VAE from {self.target_latent_vae_path}")
        self.target_latent_vae = AutoencoderKL.from_pretrained(self.target_latent_vae_path)
        # the target VAE is only used for inference and is kept in float32
        self.target_latent_vae.to(self.device, dtype=torch.float32)
        self.target_latent_vae.eval()
        self.target_vae_scale_factor = 2 ** (len(self.target_latent_vae.config['block_out_channels']) - 1)
    else:
        self.target_latent_vae = None
        self.target_vae_scale_factor = self.vae_scale_factor
= start_step + + self.print(f"Training VAE") + self.print(f" - Training folder: {self.training_folder}") + self.print(f" - Batch size: {self.batch_size}") + self.print(f" - Learning rate: {self.learning_rate}") + self.print(f" - Epochs: {num_epochs}") + self.print(f" - Max steps: {self.max_steps}") + + # load vae + self.load_vae() + + params = [] + + # only set last 2 layers to trainable + for param in self.vae.decoder.parameters(): + param.requires_grad = False + + train_all = 'all' in self.blocks_to_train + + if train_all: + params = list(self.vae.decoder.named_parameters()) + self.vae.decoder.requires_grad_(True) + if self.train_encoder: + # encoder + params += list(self.vae.encoder.named_parameters()) + self.vae.encoder.requires_grad_(True) + else: + # mid_block + if train_all or 'mid_block' in self.blocks_to_train: + params += list(self.vae.decoder.mid_block.named_parameters()) + self.vae.decoder.mid_block.requires_grad_(True) + # up_blocks + if train_all or 'up_blocks' in self.blocks_to_train: + params += list(self.vae.decoder.up_blocks.named_parameters()) + self.vae.decoder.up_blocks.requires_grad_(True) + # conv_out (single conv layer output) + if train_all or 'conv_out' in self.blocks_to_train: + params += list(self.vae.decoder.conv_out.named_parameters()) + self.vae.decoder.conv_out.requires_grad_(True) + + if self.style_weight > 0 or self.content_weight > 0: + self.setup_vgg19() + # self.vgg_19.requires_grad_(False) + self.vgg_19.eval() + + if self.use_critic: + self.critic.setup() + + if self.clip_weight > 0: + self.setup_clip() + + if self.lpips_weight > 0 and self.lpips_loss is None: + # self.lpips_loss = lpips.LPIPS(net='vgg') + self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16) + + if self.only_if_contains is not None: + orig_params = params + params = [] + for name, param in orig_params: + for contains in self.only_if_contains: + if contains in name: + params.append(param) + break + + optimizer = get_optimizer(params, 
self.optimizer_type, self.learning_rate, + optimizer_params=self.optimizer_params) + + # setup scheduler + # todo allow other schedulers + scheduler = torch.optim.lr_scheduler.ConstantLR( + optimizer, + total_iters=num_steps, + factor=1, + # verbose=False + ) + + # setup tqdm progress bar + self.progress_bar = tqdm( + total=num_steps, + desc='Training VAE', + leave=True + ) + + # sample first + self.sample() + blank_losses = OrderedDict({ + "total": [], + "lpips": [], + "style": [], + "content": [], + "mse": [], + "mae": [], + "lat_mse": [], + "mvl": [], + "ltv": [], + "lpm": [], + "kl": [], + "tv": [], + "clip": [], + "pool": [], + "ptn": [], + "crD": [], + "crG": [], + }) + epoch_losses = copy.deepcopy(blank_losses) + log_losses = copy.deepcopy(blank_losses) + # range start at self.epoch_num go to self.epochs + + latent_size = self.resolution // self.vae_scale_factor + + for epoch in range(self.epoch_num, self.epochs, 1): + if self.step_num >= self.max_steps: + break + for batch in self.data_loader: + if self.step_num >= self.max_steps: + break + with torch.no_grad(): + batch = batch.to(self.device, dtype=self.torch_dtype) + + if self.random_scaling: + # only random scale 0.5 of the time + if random.random() < 0.5: + # random scale the batch + scale_factor = 0.25 + else: + scale_factor = 0.5 + new_size = (int(batch.shape[2] * scale_factor), int(batch.shape[3] * scale_factor)) + # make sure it is vae divisible + new_size = (new_size[0] // self.vae_scale_factor * self.vae_scale_factor, + new_size[1] // self.vae_scale_factor * self.vae_scale_factor) + + + # resize so it matches size of vae evenly + if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0: + batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor, + batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch) + + target_latent = None + lat_mse_loss = torch.tensor(0.0, device=self.device) + if self.target_latent_vae is not None: 
+ target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor + target_input_size = (int(batch.shape[2] * target_input_scale), int(batch.shape[3] * target_input_scale)) + # resize to target input size + target_input_batch = Resize(target_input_size)(batch).to(self.device, dtype=torch.float32) + target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach() + # shift scale it + shift = self.target_latent_vae.config['shift_factor'] if self.target_latent_vae.config['shift_factor'] is not None else 0 + target_latent = self.target_latent_vae.config['scaling_factor'] * (target_latent - shift) + target_latent = target_latent.to(self.device, dtype=self.torch_dtype) + + + # forward pass + # grad only if eq_vae + with torch.set_grad_enabled(self.train_encoder): + if self.vae_type != 'AutoencoderKL': + # AutoencoderTiny cannot do latent distribution sampling + latents = self.vae.encode(batch, return_dict=False)[0] + mu, logvar = None, None + else: + dgd = self.vae.encode(batch).latent_dist + mu, logvar = dgd.mean, dgd.logvar + latents = dgd.sample() + + # scale shift latent to config + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + latents = self.vae.config['scaling_factor'] * (latents - shift) + + if target_latent is not None and self.train_encoder: + # forward_latents = target_latent.detach() + lat_mse_loss = torch.nn.MSELoss()(target_latent.float(), latents.float()) + latents = target_latent.detach() + forward_latents = target_latent.detach() + + elif self.eq_vae: + # process flips, rotate, scale + latent_chunks = list(latents.chunk(latents.shape[0], dim=0)) + batch_chunks = list(batch.chunk(batch.shape[0], dim=0)) + out_chunks = [] + for i in range(len(latent_chunks)): + try: + do_rotate = random.randint(0, 3) + do_flip_x = random.randint(0, 1) + do_flip_y = random.randint(0, 1) + do_scale = random.randint(0, 1) + if do_rotate > 0: + latent_chunks[i] = 
torch.rot90(latent_chunks[i], do_rotate, (2, 3)) + batch_chunks[i] = torch.rot90(batch_chunks[i], do_rotate, (2, 3)) + if do_flip_x > 0: + latent_chunks[i] = torch.flip(latent_chunks[i], [2]) + batch_chunks[i] = torch.flip(batch_chunks[i], [2]) + if do_flip_y > 0: + latent_chunks[i] = torch.flip(latent_chunks[i], [3]) + batch_chunks[i] = torch.flip(batch_chunks[i], [3]) + + # resize latent to fit + if latent_chunks[i].shape[2] != latent_size or latent_chunks[i].shape[3] != latent_size: + latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], size=(latent_size, latent_size), mode='bilinear', align_corners=False) + + # if do_scale > 0: + # scale = 2 + # start_latent_h = latent_chunks[i].shape[2] + # start_latent_w = latent_chunks[i].shape[3] + # start_batch_h = batch_chunks[i].shape[2] + # start_batch_w = batch_chunks[i].shape[3] + # latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False) + # batch_chunks[i] = torch.nn.functional.interpolate(batch_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False) + # # random crop. 
latent is smaller than match but crops need to match + # latent_x = random.randint(0, latent_chunks[i].shape[2] - start_latent_h) + # latent_y = random.randint(0, latent_chunks[i].shape[3] - start_latent_w) + # batch_x = latent_x * self.vae_scale_factor + # batch_y = latent_y * self.vae_scale_factor + + # # crop + # latent_chunks[i] = latent_chunks[i][:, :, latent_x:latent_x + start_latent_h, latent_y:latent_y + start_latent_w] + # batch_chunks[i] = batch_chunks[i][:, :, batch_x:batch_x + start_batch_h, batch_y:batch_y + start_batch_w] + except Exception as e: + print(f"Error processing image {i}: {e}") + traceback.print_exc() + raise e + out_chunks.append(latent_chunks[i]) + latents = torch.cat(out_chunks, dim=0) + # do dropout + if self.dropout > 0: + forward_latents = channel_dropout(latents, self.dropout) + else: + forward_latents = latents + + # resize batch to resolution if needed + if batch_chunks[0].shape[2] != self.resolution or batch_chunks[0].shape[3] != self.resolution: + batch_chunks = [torch.nn.functional.interpolate(b, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False) for b in batch_chunks] + batch = torch.cat(batch_chunks, dim=0) + + else: + # latents.detach().requires_grad_(True) + forward_latents = latents + + forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype) + + if not self.train_encoder: + # detach latents if not training encoder + forward_latents = forward_latents.detach() + + # shift latents to match vae config + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + forward_latents = forward_latents / self.vae.config['scaling_factor'] + shift + + pooled_outputs = None + if self.do_pooled_exits: + pred, pooled_outputs = self.vae.decode_with_pooled_exits(forward_latents) + else: + pred = self.vae.decode(forward_latents).sample + + # Run through VGG19 + if self.style_weight > 0 or self.content_weight > 0: + stacked = torch.cat([pred, batch], dim=0) + 
stacked = (stacked / 2 + 0.5).clamp(0, 1) + self.vgg_19(stacked) + + if self.use_critic: + stacked = torch.cat([pred, batch], dim=0) + critic_d_loss = self.critic.step(stacked.detach()) + else: + critic_d_loss = 0.0 + + style_loss = self.get_style_loss() * self.style_weight + content_loss = self.get_content_loss() * self.content_weight + kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight + mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight + mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight + pool_loss = self.get_pooled_output_loss(pooled_outputs, batch) + if self.clip_weight > 0: + clip_loss = self.get_clip_loss(pred, batch) * self.clip_weight + else: + clip_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + if self.lpips_weight > 0: + lpips_loss = self.lpips_loss( + pred.clamp(-1, 1).to(self.device, dtype=torch.bfloat16), + batch.clamp(-1, 1).to(self.device, dtype=torch.bfloat16) + ).float().mean() * self.lpips_weight + else: + lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight + pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight + if self.use_critic: + stacked = torch.cat([pred, batch], dim=0) + critic_gen_loss = self.critic.get_critic_loss(stacked) * self.critic_weight + + # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it + if self.lpips_weight > 0: + max_target = lpips_loss.abs() * 0.1 + with torch.no_grad(): + crit_g_scaler = 1.0 + if critic_gen_loss.abs() > max_target: + crit_g_scaler = max_target / critic_gen_loss.abs() + + critic_gen_loss *= crit_g_scaler + else: + critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + if self.mv_loss_weight > 0: + mv_loss = self.get_mean_variance_loss(latents) * self.mv_loss_weight + else: + mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + if self.ltv_weight > 0: + ltv_loss = 
self.get_ltv_loss(latents, batch) * self.ltv_weight + else: + ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + if self.lpm_weight > 0: + lpm_loss = self.get_latent_pixel_matching_loss(latents, batch) * self.lpm_weight + else: + lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss + clip_loss + pool_loss + + # check if loss is NaN or Inf + if torch.isnan(loss) or torch.isinf(loss): + self.print(f"Loss is NaN or Inf, stopping at step {self.step_num}") + self.print(f" - Style loss: {style_loss.item()}") + self.print(f" - Content loss: {content_loss.item()}") + self.print(f" - KLD loss: {kld_loss.item()}") + self.print(f" - MSE loss: {mse_loss.item()}") + self.print(f" - MAE loss: {mae_loss.item()}") + self.print(f" - Latent MSE loss: {lat_mse_loss.item()}") + self.print(f" - LPIPS loss: {lpips_loss.item()}") + self.print(f" - TV loss: {tv_loss.item()}") + self.print(f" - Pattern loss: {pattern_loss.item()}") + self.print(f" - CLIP loss: {clip_loss.item()}") + self.print(f" - Pooled output loss: {pool_loss.item()}") + self.print(f" - Critic gen loss: {critic_gen_loss.item()}") + self.print(f" - Critic D loss: {critic_d_loss}") + self.print(f" - Mean variance loss: {mv_loss.item()}") + self.print(f" - Latent TV loss: {ltv_loss.item()}") + self.print(f" - Latent pixel matching loss: {lpm_loss.item()}") + self.print(f" - Total loss: {loss.item()}") + self.print(f" - Stopping training") + exit(1) + + # Backward pass and optimization + optimizer.zero_grad() + loss.backward() + optimizer.step() + scheduler.step() + + # update progress bar + loss_value = loss.item() + # get exponent like 3.54e-4 + loss_string = f"loss: {loss_value:.2e}" + if self.lpips_weight > 0: + loss_string += f" lpips: {lpips_loss.item():.2e}" + if self.content_weight > 0: + loss_string += f" cnt: 
{content_loss.item():.2e}" + if self.style_weight > 0: + loss_string += f" sty: {style_loss.item():.2e}" + if self.kld_weight > 0: + loss_string += f" kld: {kld_loss.item():.2e}" + if self.mse_weight > 0: + loss_string += f" mse: {mse_loss.item():.2e}" + if self.mae_weight > 0: + loss_string += f" mae: {mae_loss.item():.2e}" + if self.target_latent_vae: + loss_string += f" lat_mse: {lat_mse_loss.item():.2e}" + if self.tv_weight > 0: + loss_string += f" tv: {tv_loss.item():.2e}" + if self.pattern_weight > 0: + loss_string += f" ptn: {pattern_loss.item():.2e}" + if self.clip_weight > 0: + loss_string += f" clip: {clip_loss.item():.2e}" + if self.do_pooled_exits: + loss_string += f" pool: {pool_loss.item():.2e}" + if self.use_critic and self.critic_weight > 0: + loss_string += f" crG: {critic_gen_loss.item():.2e}" + if self.use_critic: + loss_string += f" crD: {critic_d_loss:.2e}" + if self.mv_loss_weight > 0: + loss_string += f" mvl: {mv_loss:.2e}" + if self.ltv_weight > 0: + loss_string += f" ltv: {ltv_loss:.2e}" + if self.lpm_weight > 0: + loss_string += f" lpm: {lpm_loss:.2e}" + + + if hasattr(optimizer, 'get_avg_learning_rate'): + learning_rate = optimizer.get_avg_learning_rate() + elif self.optimizer_type.startswith('dadaptation') or \ + self.optimizer_type.lower().startswith('prodigy'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] + + lr_critic_string = '' + if self.use_critic: + lr_critic = self.critic.get_lr() + lr_critic_string = f" lrC: {lr_critic:.1e}" + + self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}") + self.progress_bar.set_description(f"E: {epoch}") + self.progress_bar.update(1) + + epoch_losses["total"].append(loss_value) + epoch_losses["lpips"].append(lpips_loss.item()) + epoch_losses["style"].append(style_loss.item()) + epoch_losses["content"].append(content_loss.item()) + 
epoch_losses["mse"].append(mse_loss.item()) + epoch_losses["mae"].append(mae_loss.item()) + epoch_losses["lat_mse"].append(lat_mse_loss.item()) + epoch_losses["kl"].append(kld_loss.item()) + epoch_losses["tv"].append(tv_loss.item()) + epoch_losses["ptn"].append(pattern_loss.item()) + epoch_losses["clip"].append(clip_loss.item()) + epoch_losses["pool"].append(pool_loss.item()) + epoch_losses["crG"].append(critic_gen_loss.item()) + epoch_losses["crD"].append(critic_d_loss) + epoch_losses["mvl"].append(mv_loss.item()) + epoch_losses["ltv"].append(ltv_loss.item()) + epoch_losses["lpm"].append(lpm_loss.item()) + + log_losses["total"].append(loss_value) + log_losses["lpips"].append(lpips_loss.item()) + log_losses["style"].append(style_loss.item()) + log_losses["content"].append(content_loss.item()) + log_losses["mse"].append(mse_loss.item()) + log_losses["mae"].append(mae_loss.item()) + log_losses["lat_mse"].append(lat_mse_loss.item()) + log_losses["kl"].append(kld_loss.item()) + log_losses["tv"].append(tv_loss.item()) + log_losses["ptn"].append(pattern_loss.item()) + log_losses["clip"].append(clip_loss.item()) + log_losses["pool"].append(pool_loss.item()) + log_losses["crG"].append(critic_gen_loss.item()) + log_losses["crD"].append(critic_d_loss) + log_losses["mvl"].append(mv_loss.item()) + log_losses["ltv"].append(ltv_loss.item()) + log_losses["lpm"].append(lpm_loss.item()) + + # don't do on first step + if self.step_num != start_step: + if self.sample_every and self.step_num % self.sample_every == 0: + # print above the progress bar + self.print(f"Sampling at step {self.step_num}") + self.sample(self.step_num) + + if self.save_every and self.step_num % self.save_every == 0: + # print above the progress bar + self.print(f"Saving at step {self.step_num}") + self.save(self.step_num) + + if self.log_every and self.step_num % self.log_every == 0: + # log to tensorboard + if self.writer is not None: + # get avg loss + for key in log_losses: + log_losses[key] = 
sum(log_losses[key]) / (len(log_losses[key]) + 1e-6) + # if log_losses[key] > 0: + self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num) + # reset log losses + log_losses = copy.deepcopy(blank_losses) + + self.step_num += 1 + # end epoch + if self.writer is not None: + eps = 1e-6 + # get avg loss + for key in epoch_losses: + epoch_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + eps) + if epoch_losses[key] > 0: + self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch) + # reset epoch losses + epoch_losses = copy.deepcopy(blank_losses) + + self.save() diff --git a/ai-toolkit/jobs/process/__init__.py b/ai-toolkit/jobs/process/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..387be08853c3bcb5cbc551d0b1c1c99dad124df6 --- /dev/null +++ b/ai-toolkit/jobs/process/__init__.py @@ -0,0 +1,15 @@ +from .BaseExtractProcess import BaseExtractProcess +from .ExtractLoconProcess import ExtractLoconProcess +from .ExtractLoraProcess import ExtractLoraProcess +from .BaseProcess import BaseProcess +from .BaseTrainProcess import BaseTrainProcess +from .TrainVAEProcess import TrainVAEProcess +from .BaseMergeProcess import BaseMergeProcess +from .TrainSliderProcess import TrainSliderProcess +from .TrainSliderProcessOld import TrainSliderProcessOld +from .TrainSDRescaleProcess import TrainSDRescaleProcess +from .ModRescaleLoraProcess import ModRescaleLoraProcess +from .GenerateProcess import GenerateProcess +from .BaseExtensionProcess import BaseExtensionProcess +from .TrainESRGANProcess import TrainESRGANProcess +from .BaseSDTrainProcess import BaseSDTrainProcess diff --git a/ai-toolkit/jobs/process/models/critic.py b/ai-toolkit/jobs/process/models/critic.py new file mode 100644 index 0000000000000000000000000000000000000000..118db5a0ccc2a39f3396899e9ab6f7f59f9ea3d7 --- /dev/null +++ b/ai-toolkit/jobs/process/models/critic.py @@ -0,0 +1,234 @@ +import glob +import os +from typing import TYPE_CHECKING, Union + +import 
numpy as np +import torch +import torch.nn as nn +from safetensors.torch import load_file, save_file + +from toolkit.losses import get_gradient_penalty +from toolkit.metadata import get_meta_for_safetensors +from toolkit.optimizer import get_optimizer +from toolkit.train_tools import get_torch_dtype + + +class MeanReduce(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inputs): + # global mean over spatial dims (keeps channel/batch) + return torch.mean(inputs, dim=(2, 3), keepdim=True) + + +class SelfAttention2d(nn.Module): + """ + Lightweight self-attention layer (SAGAN-style) that keeps spatial + resolution unchanged. Adds minimal params / compute but improves + long-range modelling – helpful for variable-sized inputs. + """ + def __init__(self, in_channels: int): + super().__init__() + self.query = nn.Conv1d(in_channels, in_channels // 8, 1) + self.key = nn.Conv1d(in_channels, in_channels // 8, 1) + self.value = nn.Conv1d(in_channels, in_channels, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + B, C, H, W = x.shape + flat = x.view(B, C, H * W) # (B,C,N) + q = self.query(flat).permute(0, 2, 1) # (B,N,C//8) + k = self.key(flat) # (B,C//8,N) + attn = torch.bmm(q, k) # (B,N,N) + attn = attn.softmax(dim=-1) # softmax along last dim + v = self.value(flat) # (B,C,N) + out = torch.bmm(v, attn.permute(0, 2, 1)) # (B,C,N) + out = out.view(B, C, H, W) # restore spatial dims + return self.gamma * out + x # residual + + +class CriticModel(nn.Module): + def __init__(self, base_channels: int = 64): + super().__init__() + + def sn_conv(in_c, out_c, k, s, p): + return nn.utils.spectral_norm( + nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p) + ) + + layers = [ + # initial down-sample + sn_conv(3, base_channels, 3, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ] + + in_c = base_channels + # progressive downsamples ×3 (64→128→256→512) + for _ in range(3): + out_c = min(in_c * 2, 1024) + layers += [ + sn_conv(in_c, 
out_c, 3, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ] + # single attention block after reaching 256 channels + if out_c == 256: + layers += [SelfAttention2d(out_c)] + in_c = out_c + + # extra depth (keeps spatial size) + layers += [ + sn_conv(in_c, 1024, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + + # final 1-channel prediction map + sn_conv(1024, 1, 3, 1, 1), + MeanReduce(), # → (B,1,1,1) + nn.Flatten(), # → (B,1) + ] + + self.main = nn.Sequential(*layers) + + def forward(self, inputs): + # force full-precision inside AMP ctx for stability + with torch.cuda.amp.autocast(False): + return self.main(inputs.float()) + + +if TYPE_CHECKING: + from jobs.process.TrainVAEProcess import TrainVAEProcess + from jobs.process.TrainESRGANProcess import TrainESRGANProcess + + +class Critic: + process: Union['TrainVAEProcess', 'TrainESRGANProcess'] + + def __init__( + self, + learning_rate=1e-5, + device='cpu', + optimizer='adam', + num_critic_per_gen=1, + dtype='float32', + lambda_gp=10, + start_step=0, + warmup_steps=1000, + process=None, + optimizer_params=None, + ): + self.learning_rate = learning_rate + self.device = device + self.optimizer_type = optimizer + self.num_critic_per_gen = num_critic_per_gen + self.dtype = dtype + self.torch_dtype = get_torch_dtype(self.dtype) + self.process = process + self.model = None + self.optimizer = None + self.scheduler = None + self.warmup_steps = warmup_steps + self.start_step = start_step + self.lambda_gp = lambda_gp + + if optimizer_params is None: + optimizer_params = {} + self.optimizer_params = optimizer_params + self.print = self.process.print + print(f" Critic config: {self.__dict__}") + + def setup(self): + self.model = CriticModel().to(self.device) + self.load_weights() + self.model.train() + self.model.requires_grad_(True) + params = self.model.parameters() + self.optimizer = get_optimizer( + params, + self.optimizer_type, + self.learning_rate, + optimizer_params=self.optimizer_params, + ) + self.scheduler = 
torch.optim.lr_scheduler.ConstantLR( + self.optimizer, + total_iters=self.process.max_steps * self.num_critic_per_gen, + factor=1, + # verbose=False, + ) + + def load_weights(self): + path_to_load = None + self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}") + files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors")) + if files: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + else: + self.print(" - No checkpoint found, starting from scratch") + if path_to_load: + self.model.load_state_dict(load_file(path_to_load)) + + def save(self, step=None): + self.process.update_training_metadata() + save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name) + step_num = f"_{str(step).zfill(9)}" if step is not None else '' + save_path = os.path.join( + self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors" + ) + save_file(self.model.state_dict(), save_path, save_meta) + self.print(f"Saved critic to {save_path}") + + def get_critic_loss(self, vgg_output): + # (caller still passes combined [pred|target] images) + if self.start_step > self.process.step_num: + return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device) + + warmup_scaler = 1.0 + if self.process.step_num < self.start_step + self.warmup_steps: + warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps + + self.model.eval() + self.model.requires_grad_(False) + + vgg_pred, _ = torch.chunk(vgg_output.float(), 2, dim=0) + stacked_output = self.model(vgg_pred) + return (-torch.mean(stacked_output)) * warmup_scaler + + def step(self, vgg_output): + self.model.train() + self.model.requires_grad_(True) + self.optimizer.zero_grad() + + critic_losses = [] + inputs = vgg_output.detach().to(self.device, dtype=torch.float32) + + vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0) + stacked_output = 
self.model(inputs).float() + out_pred, out_target = torch.chunk(stacked_output, 2, dim=0) + + # hinge loss + gradient penalty + loss_real = torch.relu(1.0 - out_target).mean() + loss_fake = torch.relu(1.0 + out_pred).mean() + gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device) + critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty + + critic_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + self.scheduler.step() + critic_losses.append(critic_loss.item()) + + return float(np.mean(critic_losses)) + + def get_lr(self): + if hasattr(self.optimizer, 'get_avg_learning_rate'): + learning_rate = self.optimizer.get_avg_learning_rate() + elif self.optimizer_type.startswith('dadaptation') or \ + self.optimizer_type.lower().startswith('prodigy'): + learning_rate = ( + self.optimizer.param_groups[0]["d"] * + self.optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = self.optimizer.param_groups[0]['lr'] + return learning_rate diff --git a/ai-toolkit/jobs/process/models/vgg19_critic.py b/ai-toolkit/jobs/process/models/vgg19_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..4d7f74f9c8f502e21bbe11f9136beeeff75eea41 --- /dev/null +++ b/ai-toolkit/jobs/process/models/vgg19_critic.py @@ -0,0 +1,220 @@ +import glob +import os + +import numpy as np +import torch +import torch.nn as nn +from safetensors.torch import load_file, save_file + +from toolkit.losses import get_gradient_penalty +from toolkit.metadata import get_meta_for_safetensors +from toolkit.optimizer import get_optimizer +from toolkit.train_tools import get_torch_dtype + +from typing import TYPE_CHECKING, Union + + +class MeanReduce(nn.Module): + def __init__(self): + super(MeanReduce, self).__init__() + + def forward(self, inputs): + return torch.mean(inputs, dim=(1, 2, 3), keepdim=True) + + +class Vgg19Critic(nn.Module): + def __init__(self): + # vgg19 input (bs, 3, 512, 512) + 
# pool1 (bs, 64, 256, 256) + # pool2 (bs, 128, 128, 128) + # pool3 (bs, 256, 64, 64) + # pool4 (bs, 512, 32, 32) <- take this input + + super(Vgg19Critic, self).__init__() + self.main = nn.Sequential( + # input (bs, 512, 32, 32) + # nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), + nn.utils.spectral_norm( # SN keeps D’s scale in check + nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1) + ), + nn.LeakyReLU(0.2), # (bs, 512, 16, 16) + # nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1), + nn.utils.spectral_norm( + nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1) + ), + nn.LeakyReLU(0.2), # (bs, 512, 8, 8) + # nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1), + nn.utils.spectral_norm( + nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1) + ), + # (bs, 1, 4, 4) + MeanReduce(), # (bs, 1, 1, 1) + nn.Flatten(), # (bs, 1) + + # nn.Flatten(), # (128*8*8) = 8192 + # nn.Linear(128 * 8 * 8, 1) + ) + + def forward(self, inputs): + # return self.main(inputs) + with torch.cuda.amp.autocast(False): + return self.main(inputs.float()) + + +if TYPE_CHECKING: + from jobs.process.TrainVAEProcess import TrainVAEProcess + from jobs.process.TrainESRGANProcess import TrainESRGANProcess + + +class Critic: + process: Union['TrainVAEProcess', 'TrainESRGANProcess'] + + def __init__( + self, + learning_rate=1e-5, + device='cpu', + optimizer='adam', + num_critic_per_gen=1, + dtype='float32', + lambda_gp=10, + start_step=0, + warmup_steps=1000, + process=None, + optimizer_params=None, + ): + self.learning_rate = learning_rate + self.device = device + self.optimizer_type = optimizer + self.num_critic_per_gen = num_critic_per_gen + self.dtype = dtype + self.torch_dtype = get_torch_dtype(self.dtype) + self.process = process + self.model = None + self.optimizer = None + self.scheduler = None + self.warmup_steps = warmup_steps + self.start_step = start_step + self.lambda_gp = lambda_gp + + if optimizer_params is None: + optimizer_params = {} + 
self.optimizer_params = optimizer_params + self.print = self.process.print + print(f" Critic config: {self.__dict__}") + + def setup(self): + self.model = Vgg19Critic().to(self.device) + self.load_weights() + self.model.train() + self.model.requires_grad_(True) + params = self.model.parameters() + self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate, + optimizer_params=self.optimizer_params) + self.scheduler = torch.optim.lr_scheduler.ConstantLR( + self.optimizer, + total_iters=self.process.max_steps * self.num_critic_per_gen, + factor=1, + verbose=False + ) + + def load_weights(self): + path_to_load = None + self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}") + files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors")) + if files and len(files) > 0: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + else: + self.print(f" - No checkpoint found, starting from scratch") + if path_to_load: + self.model.load_state_dict(load_file(path_to_load)) + + def save(self, step=None): + self.process.update_training_metadata() + save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name) + step_num = '' + if step is not None: + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors") + save_file(self.model.state_dict(), save_path, save_meta) + self.print(f"Saved critic to {save_path}") + + def get_critic_loss(self, vgg_output): + if self.start_step > self.process.step_num: + return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device) + + warmup_scaler = 1.0 + # we need a warmup when we come on of 1000 steps + # we want to scale the loss by 0.0 at self.start_step steps and 1.0 at self.start_step + warmup_steps + if self.process.step_num < self.start_step + 
self.warmup_steps: + warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps + # set model to not train for generator loss + self.model.eval() + self.model.requires_grad_(False) + # vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0) + vgg_pred, vgg_target = torch.chunk(vgg_output.float(), 2, dim=0) + + # run model + stacked_output = self.model(vgg_pred) + + return (-torch.mean(stacked_output)) * warmup_scaler + + def step(self, vgg_output): + + # train critic here + self.model.train() + self.model.requires_grad_(True) + self.optimizer.zero_grad() + + critic_losses = [] + # inputs = vgg_output.detach() + # inputs = inputs.to(self.device, dtype=self.torch_dtype) + inputs = vgg_output.detach().to(self.device, dtype=torch.float32) + self.optimizer.zero_grad() + + vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0) + + # stacked_output = self.model(inputs).float() + # out_pred, out_target = torch.chunk(stacked_output, 2, dim=0) + + # # Compute gradient penalty + # gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device) + + # # Compute WGAN-GP critic loss + # critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty + + stacked_output = self.model(inputs).float() + out_pred, out_target = torch.chunk(stacked_output, 2, dim=0) + + # ── hinge loss ── + loss_real = torch.relu(1.0 - out_target).mean() + loss_fake = torch.relu(1.0 + out_pred).mean() + + # gradient penalty (unchanged helper) + gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device) + + critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty + + critic_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + self.scheduler.step() + critic_losses.append(critic_loss.item()) + + # avg loss + loss = np.mean(critic_losses) + return loss + + def get_lr(self): + if self.optimizer_type.startswith('dadaptation'): + learning_rate 
= ( + self.optimizer.param_groups[0]["d"] * + self.optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = self.optimizer.param_groups[0]['lr'] + + return learning_rate + diff --git a/ai-toolkit/notebooks/FLUX_1_dev_LoRA_Training.ipynb b/ai-toolkit/notebooks/FLUX_1_dev_LoRA_Training.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..8cfcd1fedfc941ac1a050f39499f77d303e23783 --- /dev/null +++ b/ai-toolkit/notebooks/FLUX_1_dev_LoRA_Training.ipynb @@ -0,0 +1,291 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "id": "zl-S0m3pkQC5" + }, + "source": [ + "# AI Toolkit by Ostris\n", + "## FLUX.1-dev Training\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BvAG0GKAh59G" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/ostris/ai-toolkit\n", + "!mkdir -p /content/dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UFUW4ZMmnp1V" + }, + "source": [ + "Put your image dataset in the `/content/dataset` folder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XGZqVER_aQJW" + }, + "outputs": [], + "source": [ + "!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OV0HnOI6o8V6" + }, + "source": [ + "## Model License\n", + "Training currently only works with FLUX.1-dev. Which means anything you train will inherit the non-commercial license. It is also a gated model, so you need to accept the license on HF before using it. Otherwise, this will fail. 
Here are the required steps to setup a license.\n", + "\n", + "Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)\n", + "\n", + "[Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and place it in the next cell after running it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3yZZdhFRoj2m" + }, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "# Prompt for the token\n", + "hf_token = getpass.getpass('Enter your HF access token and press enter: ')\n", + "\n", + "# Set the environment variable\n", + "os.environ['HF_TOKEN'] = hf_token\n", + "\n", + "print(\"HF_TOKEN environment variable has been set.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9gO2EzQ1kQC8" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "sys.path.append('/content/ai-toolkit')\n", + "from toolkit.job import run_job\n", + "from collections import OrderedDict\n", + "from PIL import Image\n", + "import os\n", + "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N8UUFzVRigbC" + }, + "source": [ + "## Setup\n", + "\n", + "This is your config. It is documented pretty well. Normally you would do this as a yaml file, but for colab, this will work. This will run as is without modification, but feel free to edit as you want." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_t28QURYjRQO" + }, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "job_to_run = OrderedDict([\n", + " ('job', 'extension'),\n", + " ('config', OrderedDict([\n", + " # this name will be the folder and filename name\n", + " ('name', 'my_first_flux_lora_v1'),\n", + " ('process', [\n", + " OrderedDict([\n", + " ('type', 'sd_trainer'),\n", + " # root folder to save training sessions/samples/weights\n", + " ('training_folder', '/content/output'),\n", + " # uncomment to see performance stats in the terminal every N steps\n", + " #('performance_log_every', 1000),\n", + " ('device', 'cuda:0'),\n", + " # if a trigger word is specified, it will be added to captions of training data if it does not already exist\n", + " # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word\n", + " # ('trigger_word', 'image'),\n", + " ('network', OrderedDict([\n", + " ('type', 'lora'),\n", + " ('linear', 16),\n", + " ('linear_alpha', 16)\n", + " ])),\n", + " ('save', OrderedDict([\n", + " ('dtype', 'float16'), # precision to save\n", + " ('save_every', 250), # save every this many steps\n", + " ('max_step_saves_to_keep', 4) # how many intermittent saves to keep\n", + " ])),\n", + " ('datasets', [\n", + " # datasets are a folder of images. captions need to be txt files with the same name as the image\n", + " # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently\n", + " # images will automatically be resized and bucketed into the resolution specified\n", + " OrderedDict([\n", + " ('folder_path', '/content/dataset'),\n", + " ('caption_ext', 'txt'),\n", + " ('caption_dropout_rate', 0.05), # will drop out the caption 5% of time\n", + " ('shuffle_tokens', False), # shuffle caption order, split by commas\n", + " ('cache_latents_to_disk', True), # leave this true unless you know what you're doing\n", + " ('resolution', [512, 768, 1024]) # flux enjoys multiple resolutions\n", + " ])\n", + " ]),\n", + " ('train', OrderedDict([\n", + " ('batch_size', 1),\n", + " ('steps', 2000), # total number of steps to train 500 - 4000 is a good range\n", + " ('gradient_accumulation_steps', 1),\n", + " ('train_unet', True),\n", + " ('train_text_encoder', False), # probably won't work with flux\n", + " ('content_or_style', 'balanced'), # content, style, balanced\n", + " ('gradient_checkpointing', True), # need this on unless you have a ton of vram\n", + " ('noise_scheduler', 'flowmatch'), # for training only\n", + " ('optimizer', 'adamw8bit'),\n", + " ('lr', 1e-4),\n", + "\n", + " # uncomment this to skip the pre training sample\n", + " # ('skip_first_sample', True),\n", + "\n", + " # uncomment to completely disable sampling\n", + " # ('disable_sampling', True),\n", + "\n", + " # uncomment to use new bell curved weighting. Experimental but may produce better results\n", + " # ('linear_timesteps', True),\n", + "\n", + " # ema will smooth out learning, but could slow it down. 
Recommended to leave on.\n", + " ('ema_config', OrderedDict([\n", + " ('use_ema', True),\n", + " ('ema_decay', 0.99)\n", + " ])),\n", + "\n", + " # will probably need this if gpu supports it for flux, other dtypes may not work correctly\n", + " ('dtype', 'bf16')\n", + " ])),\n", + " ('model', OrderedDict([\n", + " # huggingface model name or path\n", + " ('name_or_path', 'black-forest-labs/FLUX.1-dev'),\n", + " ('is_flux', True),\n", + " ('quantize', True), # run 8bit mixed precision\n", + " #('low_vram', True), # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.\n", + " ])),\n", + " ('sample', OrderedDict([\n", + " ('sampler', 'flowmatch'), # must match train.noise_scheduler\n", + " ('sample_every', 250), # sample every this many steps\n", + " ('width', 1024),\n", + " ('height', 1024),\n", + " ('prompts', [\n", + " # you can add [trigger] to the prompts here and it will be replaced with the trigger word\n", + " #'[trigger] holding a sign that says \\'I LOVE PROMPTS!\\'',\n", + " 'woman with red hair, playing chess at the park, bomb going off in the background',\n", + " 'a woman holding a coffee cup, in a beanie, sitting at a cafe',\n", + " 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',\n", + " 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',\n", + " 'a bear building a log cabin in the snow covered mountains',\n", + " 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',\n", + " 'hipster man with a beard, building a chair, in a wood shop',\n", + " 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',\n", + " 'a man holding a sign that says, \\'this is a sign\\'',\n", + " 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle'\n", + " ]),\n", + " ('neg', ''), # not used on 
flux\n", + " ('seed', 42),\n", + " ('walk_seed', True),\n", + " ('guidance_scale', 4),\n", + " ('sample_steps', 20)\n", + " ]))\n", + " ])\n", + " ])\n", + " ])),\n", + " # you can add any additional meta info here. [name] is replaced with config name at top\n", + " ('meta', OrderedDict([\n", + " ('name', '[name]'),\n", + " ('version', '1.0')\n", + " ]))\n", + "])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h6F1FlM2Wb3l" + }, + "source": [ + "## Run it\n", + "\n", + "Below does all the magic. Check your folders to the left. Items will be in output/LoRA/your_name_v1. In the samples folder, there are periodic samples. This doesn't work great with colab. They will be in /content/output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HkajwI8gteOh" + }, + "outputs": [], + "source": [ + "run_job(job_to_run)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hblgb5uwW5SD" + }, + "source": [ + "## Done\n", + "\n", + "Check your output dir and get your LoRA\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/ai-toolkit/notebooks/FLUX_1_schnell_LoRA_Training.ipynb b/ai-toolkit/notebooks/FLUX_1_schnell_LoRA_Training.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..652d8ccc19f8996734785182ce8de46a5c7408fb --- /dev/null +++ b/ai-toolkit/notebooks/FLUX_1_schnell_LoRA_Training.ipynb @@ -0,0 +1,296 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false, + "id": "zl-S0m3pkQC5" + }, + "source": [ + "# AI Toolkit by Ostris\n", + "## FLUX.1-schnell Training\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3cokMT-WC6rG" + }, + "outputs": [], + 
"source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "BvAG0GKAh59G" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/ostris/ai-toolkit\n", + "!mkdir -p /content/dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UFUW4ZMmnp1V" + }, + "source": [ + "Put your image dataset in the `/content/dataset` folder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "XGZqVER_aQJW" + }, + "outputs": [], + "source": [ + "!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OV0HnOI6o8V6" + }, + "source": [ + "## Model License\n", + "Training currently only works with FLUX.1-dev. Which means anything you train will inherit the non-commercial license. It is also a gated model, so you need to accept the license on HF before using it. Otherwise, this will fail. Here are the required steps to setup a license.\n", + "\n", + "Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)\n", + "\n", + "[Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and place it in the next cell after running it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3yZZdhFRoj2m" + }, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "# Prompt for the token\n", + "hf_token = getpass.getpass('Enter your HF access token and press enter: ')\n", + "\n", + "# Set the environment variable\n", + "os.environ['HF_TOKEN'] = hf_token\n", + "\n", + "print(\"HF_TOKEN environment variable has been set.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "9gO2EzQ1kQC8" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "sys.path.append('/content/ai-toolkit')\n", + "from toolkit.job import run_job\n", + "from collections import OrderedDict\n", + "from PIL import Image\n", + "import os\n", + "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N8UUFzVRigbC" + }, + "source": [ + "## Setup\n", + "\n", + "This is your config. It is documented pretty well. Normally you would do this as a yaml file, but for colab, this will work. This will run as is without modification, but feel free to edit as you want." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "_t28QURYjRQO" + }, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "job_to_run = OrderedDict([\n", + " ('job', 'extension'),\n", + " ('config', OrderedDict([\n", + " # this name will be the folder and filename name\n", + " ('name', 'my_first_flux_lora_v1'),\n", + " ('process', [\n", + " OrderedDict([\n", + " ('type', 'sd_trainer'),\n", + " # root folder to save training sessions/samples/weights\n", + " ('training_folder', '/content/output'),\n", + " # uncomment to see performance stats in the terminal every N steps\n", + " #('performance_log_every', 1000),\n", + " ('device', 'cuda:0'),\n", + " # if a trigger word is specified, it will be added to captions of training data if it does not already exist\n", + " # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word\n", + " # ('trigger_word', 'image'),\n", + " ('network', OrderedDict([\n", + " ('type', 'lora'),\n", + " ('linear', 16),\n", + " ('linear_alpha', 16)\n", + " ])),\n", + " ('save', OrderedDict([\n", + " ('dtype', 'float16'), # precision to save\n", + " ('save_every', 250), # save every this many steps\n", + " ('max_step_saves_to_keep', 4) # how many intermittent saves to keep\n", + " ])),\n", + " ('datasets', [\n", + " # datasets are a folder of images. captions need to be txt files with the same name as the image\n", + " # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently\n", + " # images will automatically be resized and bucketed into the resolution specified\n", + " OrderedDict([\n", + " ('folder_path', '/content/dataset'),\n", + " ('caption_ext', 'txt'),\n", + " ('caption_dropout_rate', 0.05), # will drop out the caption 5% of time\n", + " ('shuffle_tokens', False), # shuffle caption order, split by commas\n", + " ('cache_latents_to_disk', True), # leave this true unless you know what you're doing\n", + " ('resolution', [512, 768, 1024]) # flux enjoys multiple resolutions\n", + " ])\n", + " ]),\n", + " ('train', OrderedDict([\n", + " ('batch_size', 1),\n", + " ('steps', 2000), # total number of steps to train 500 - 4000 is a good range\n", + " ('gradient_accumulation_steps', 1),\n", + " ('train_unet', True),\n", + " ('train_text_encoder', False), # probably won't work with flux\n", + " ('gradient_checkpointing', True), # need the on unless you have a ton of vram\n", + " ('noise_scheduler', 'flowmatch'), # for training only\n", + " ('optimizer', 'adamw8bit'),\n", + " ('lr', 1e-4),\n", + "\n", + " # uncomment this to skip the pre training sample\n", + " # ('skip_first_sample', True),\n", + "\n", + " # uncomment to completely disable sampling\n", + " # ('disable_sampling', True),\n", + "\n", + " # uncomment to use new vell curved weighting. Experimental but may produce better results\n", + " # ('linear_timesteps', True),\n", + "\n", + " # ema will smooth out learning, but could slow it down. 
Recommended to leave on.\n", + " ('ema_config', OrderedDict([\n", + " ('use_ema', True),\n", + " ('ema_decay', 0.99)\n", + " ])),\n", + "\n", + " # will probably need this if gpu supports it for flux, other dtypes may not work correctly\n", + " ('dtype', 'bf16')\n", + " ])),\n", + " ('model', OrderedDict([\n", + " # huggingface model name or path\n", + " ('name_or_path', 'black-forest-labs/FLUX.1-schnell'),\n", + " ('assistant_lora_path', 'ostris/FLUX.1-schnell-training-adapter'), # Required for flux schnell training\n", + " ('is_flux', True),\n", + " ('quantize', True), # run 8bit mixed precision\n", + " # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary\n", + " #('low_vram', True), # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.\n", + " ])),\n", + " ('sample', OrderedDict([\n", + " ('sampler', 'flowmatch'), # must match train.noise_scheduler\n", + " ('sample_every', 250), # sample every this many steps\n", + " ('width', 1024),\n", + " ('height', 1024),\n", + " ('prompts', [\n", + " # you can add [trigger] to the prompts here and it will be replaced with the trigger word\n", + " #'[trigger] holding a sign that says \\'I LOVE PROMPTS!\\'',\n", + " 'woman with red hair, playing chess at the park, bomb going off in the background',\n", + " 'a woman holding a coffee cup, in a beanie, sitting at a cafe',\n", + " 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',\n", + " 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',\n", + " 'a bear building a log cabin in the snow covered mountains',\n", + " 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',\n", + " 'hipster man with a beard, building a chair, in a wood shop',\n", + " 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',\n", + " 'a man 
holding a sign that says, \\'this is a sign\\'',\n", + " 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle'\n", + " ]),\n", + " ('neg', ''), # not used on flux\n", + " ('seed', 42),\n", + " ('walk_seed', True),\n", + " ('guidance_scale', 1), # schnell does not do guidance\n", + " ('sample_steps', 4) # 1 - 4 works well\n", + " ]))\n", + " ])\n", + " ])\n", + " ])),\n", + " # you can add any additional meta info here. [name] is replaced with config name at top\n", + " ('meta', OrderedDict([\n", + " ('name', '[name]'),\n", + " ('version', '1.0')\n", + " ]))\n", + "])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h6F1FlM2Wb3l" + }, + "source": [ + "## Run it\n", + "\n", + "Below does all the magic. Check your folders to the left. Items will be in output/LoRA/your_name_v1. In the samples folder, there are periodic samples. This doesn't work great with colab. They will be in /content/output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HkajwI8gteOh" + }, + "outputs": [], + "source": [ + "run_job(job_to_run)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hblgb5uwW5SD" + }, + "source": [ + "## Done\n", + "\n", + "Check your output dir and get your LoRA\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/ai-toolkit/notebooks/SliderTraining.ipynb b/ai-toolkit/notebooks/SliderTraining.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..8465ec87dc2d2dce8f11e122c28c80297e3ea2b9 --- /dev/null +++ b/ai-toolkit/notebooks/SliderTraining.ipynb @@ -0,0 +1,339 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + 
"machine_shape": "hm", + "gpuType": "V100" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# AI Toolkit by Ostris\n", + "## Slider Training\n", + "\n", + "This is a quick colab demo for training sliders like can be found in my CivitAI profile https://civitai.com/user/Ostris/models . I will work on making it more user friendly, but for now, it will get you started." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/ostris/ai-toolkit" + ], + "metadata": { + "id": "BvAG0GKAh59G" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XGZqVER_aQJW" + }, + "outputs": [], + "source": [ + "!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt\n" + ] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import sys\n", + "sys.path.append('/content/ai-toolkit')\n", + "from toolkit.job import run_job\n", + "from collections import OrderedDict\n", + "from PIL import Image" + ], + "metadata": { + "collapsed": false + }, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Setup\n", + "\n", + "This is your config. It is documented pretty well. Normally you would do this as a yaml file, but for colab, this will work. This will run as is without modification, but feel free to edit as you want." 
+ ], + "metadata": { + "id": "N8UUFzVRigbC" + } + }, + { + "cell_type": "code", + "source": [ + "from collections import OrderedDict\n", + "\n", + "job_to_run = OrderedDict({\n", + " # This is the config I use on my sliders, It is solid and tested\n", + " 'job': 'train',\n", + " 'config': {\n", + " # the name will be used to create a folder in the output folder\n", + " # it will also replace any [name] token in the rest of this config\n", + " 'name': 'detail_slider_v1',\n", + " # folder will be created with name above in folder below\n", + " # it can be relative to the project root or absolute\n", + " 'training_folder': \"output/LoRA\",\n", + " 'device': 'cuda', # cpu, cuda:0, etc\n", + " # for tensorboard logging, we will make a subfolder for this job\n", + " 'log_dir': \"output/.tensorboard\",\n", + " # you can stack processes for other jobs, It is not tested with sliders though\n", + " # just use one for now\n", + " 'process': [\n", + " {\n", + " 'type': 'slider', # tells runner to run the slider process\n", + " # network is the LoRA network for a slider, I recommend to leave this be\n", + " 'network': {\n", + " 'type': \"lora\",\n", + " # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good\n", + " 'linear': 8, # \"rank\" or \"dim\"\n", + " 'linear_alpha': 4, # Do about half of rank \"alpha\"\n", + " # 'conv': 4, # for convolutional layers \"locon\"\n", + " # 'conv_alpha': 4, # Do about half of conv \"alpha\"\n", + " },\n", + " # training config\n", + " 'train': {\n", + " # this is also used in sampling. Stick with ddpm unless you know what you are doing\n", + " 'noise_scheduler': \"ddpm\", # or \"ddpm\", \"lms\", \"euler_a\"\n", + " # how many steps to train. More is not always better. 
I rarely go over 1000\n", + " 'steps': 100,\n", + " # I have had good results with 4e-4 to 1e-4 at 500 steps\n", + " 'lr': 2e-4,\n", + " # enables gradient checkpoint, saves vram, leave it on\n", + " 'gradient_checkpointing': True,\n", + " # train the unet. I recommend leaving this true\n", + " 'train_unet': True,\n", + " # train the text encoder. I don't recommend this unless you have a special use case\n", + " # for sliders we are adjusting representation of the concept (unet),\n", + " # not the description of it (text encoder)\n", + " 'train_text_encoder': False,\n", + "\n", + " # just leave unless you know what you are doing\n", + " # also supports \"dadaptation\" but set lr to 1 if you use that,\n", + " # but it learns too fast and I don't recommend it\n", + " 'optimizer': \"adamw\",\n", + " # only constant for now\n", + " 'lr_scheduler': \"constant\",\n", + " # we randomly denoise random num of steps form 1 to this number\n", + " # while training. Just leave it\n", + " 'max_denoising_steps': 40,\n", + " # works great at 1. I do 1 even with my 4090.\n", + " # higher may not work right with newer single batch stacking code anyway\n", + " 'batch_size': 1,\n", + " # bf16 works best if your GPU supports it (modern)\n", + " 'dtype': 'bf16', # fp32, bf16, fp16\n", + " # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX\n", + " # although, the way we train sliders is comparative, so it probably won't work anyway\n", + " 'noise_offset': 0.0,\n", + " },\n", + "\n", + " # the model to train the LoRA network on\n", + " 'model': {\n", + " # name_or_path can be a hugging face name, local path or url to model\n", + " # on civit ai with or without modelVersionId. 
They will be cached in /model folder\n", + " # epicRealisim v5\n", + " 'name_or_path': \"https://civitai.com/models/25694?modelVersionId=134065\",\n", + " 'is_v2': False, # for v2 models\n", + " 'is_v_pred': False, # for v-prediction models (most v2 models)\n", + " # has some issues with the dual text encoder and the way we train sliders\n", + " # it works, but weights probably need to be higher to see it.\n", + " 'is_xl': False, # for SDXL models\n", + " },\n", + "\n", + " # saving config\n", + " 'save': {\n", + " 'dtype': 'float16', # precision to save. I recommend float16\n", + " 'save_every': 50, # save every this many steps\n", + " # this will remove step counts more than this number\n", + " # allows you to save more often in case of a crash without filling up your drive\n", + " 'max_step_saves_to_keep': 2,\n", + " },\n", + "\n", + " # sampling config\n", + " 'sample': {\n", + " # must match train.noise_scheduler, this is not used here\n", + " # but may be in future and in other processes\n", + " 'sampler': \"ddpm\",\n", + " # sample every this many steps\n", + " 'sample_every': 20,\n", + " # image size\n", + " 'width': 512,\n", + " 'height': 512,\n", + " # prompts to use for sampling. Do as many as you want, but it slows down training\n", + " # pick ones that will best represent the concept you are trying to adjust\n", + " # allows some flags after the prompt\n", + " # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive\n", + " # slide are good tests. 
will inherit sample.network_multiplier if not set\n", + " # --n [string] # negative prompt, will inherit sample.neg if not set\n", + " # Only 75 tokens allowed currently\n", + " # I like to do a wide positive and negative spread so I can see a good range and stop\n", + " # early if the network is breaking down\n", + " 'prompts': [\n", + " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5\",\n", + " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3\",\n", + " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3\",\n", + " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5\",\n", + " \"a golden retriever sitting on a leather couch, --m -5\",\n", + " \"a golden retriever sitting on a leather couch --m -3\",\n", + " \"a golden retriever sitting on a leather couch --m 3\",\n", + " \"a golden retriever sitting on a leather couch --m 5\",\n", + " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5\",\n", + " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3\",\n", + " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3\",\n", + " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5\",\n", + " ],\n", + " # negative prompt used on all prompts above as default if they don't have one\n", + " 'neg': \"cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome\",\n", + " # seed for sampling. 
42 is the answer for everything\n", + " 'seed': 42,\n", + " # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc\n", + " # will start over on next sample_every so s1 is always seed\n", + " # works well if you use same prompt but want different results\n", + " 'walk_seed': False,\n", + " # cfg scale (4 to 10 is good)\n", + " 'guidance_scale': 7,\n", + " # sampler steps (20 to 30 is good)\n", + " 'sample_steps': 20,\n", + " # default network multiplier for all prompts\n", + " # since we are training a slider, I recommend overriding this with --m [number]\n", + " # in the prompts above to get both sides of the slider\n", + " 'network_multiplier': 1.0,\n", + " },\n", + "\n", + " # logging information\n", + " 'logging': {\n", + " 'log_every': 10, # log every this many steps\n", + " 'use_wandb': False, # not supported yet\n", + " 'verbose': False, # probably don't need unless you are debugging\n", + " },\n", + "\n", + " # slider training config, best for last\n", + " 'slider': {\n", + " # resolutions to train on. [ width, height ]. This is less important for sliders\n", + " # as we are not teaching the model anything it doesn't already know\n", + " # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1\n", + " # and [ 1024, 1024 ] for sd_xl\n", + " # you can do as many as you want here\n", + " 'resolutions': [\n", + " [512, 512],\n", + " # [ 512, 768 ]\n", + " # [ 768, 768 ]\n", + " ],\n", + " # slider training uses 4 combined steps for a single round. This will do it in one gradient\n", + " # step. It is highly optimized and shouldn't take anymore vram than doing without it,\n", + " # since we break down batches for gradient accumulation now. so just leave it on.\n", + " 'batch_full_slide': True,\n", + " # These are the concepts to train on. You can do as many as you want here,\n", + " # but they can conflict or outweigh each other. 
Other than experimenting, I recommend\n", + " # just doing one for good results\n", + " 'targets': [\n", + " # target_class is the base concept we are adjusting the representation of\n", + " # for example, if we are adjusting the representation of a person, we would use \"person\"\n", + " # if we are adjusting the representation of a cat, we would use \"cat\" It is not\n", + " # a keyword necessarily but what the model understands the concept to represent.\n", + " # \"person\" will affect men, women, children, etc but will not affect cats, dogs, etc\n", + " # it is the models base general understanding of the concept and everything it represents\n", + " # you can leave it blank to affect everything. In this example, we are adjusting\n", + " # detail, so we will leave it blank to affect everything\n", + " {\n", + " 'target_class': \"\",\n", + " # positive is the prompt for the positive side of the slider.\n", + " # It is the concept that will be excited and amplified in the model when we slide the slider\n", + " # to the positive side and forgotten / inverted when we slide\n", + " # the slider to the negative side. It is generally best to include the target_class in\n", + " # the prompt. You want it to be the extreme of what you want to train on. For example,\n", + " # if you want to train on fat people, you would use \"an extremely fat, morbidly obese person\"\n", + " # as the prompt. 
Not just \"fat person\"\n", + " # max 75 tokens for now\n", + " 'positive': \"high detail, 8k, intricate, detailed, high resolution, high res, high quality\",\n", + " # negative is the prompt for the negative side of the slider and works the same as positive\n", + " # it does not necessarily work the same as a negative prompt when generating images\n", + " # these need to be polar opposites.\n", + " # max 75 tokens for now\n", + " 'negative': \"blurry, boring, fuzzy, low detail, low resolution, low res, low quality\",\n", + " # the loss for this target is multiplied by this number.\n", + " # if you are doing more than one target it may be good to set less important ones\n", + " # to a lower number like 0.1 so they don't outweigh the primary target\n", + " 'weight': 1.0,\n", + " },\n", + " ],\n", + " },\n", + " },\n", + " ]\n", + " },\n", + "\n", + " # You can put any information you want here, and it will be saved in the model.\n", + " # The below is an example, but you can put your grocery list in it if you want.\n", + " # It is saved in the model so be aware of that. The software will include this\n", + " # plus some other information for you automatically\n", + " 'meta': {\n", + " # [name] gets replaced with the name above\n", + " 'name': \"[name]\",\n", + " 'version': '1.0',\n", + " # 'creator': {\n", + " # 'name': 'your name',\n", + " # 'email': 'your@gmail.com',\n", + " # 'website': 'https://your.website'\n", + " # }\n", + " }\n", + "})\n" + ], + "metadata": { + "id": "_t28QURYjRQO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Run it\n", + "\n", + "Below does all the magic. Check your folders to the left. Items will be in output/LoRA/your_name_v1. In the samples folder, there are periodic samples. This doesn't work great with colab. I'll update soon." 
+ ], + "metadata": { + "id": "h6F1FlM2Wb3l" + } + }, + { + "cell_type": "code", + "source": [ + "run_job(job_to_run)\n" + ], + "metadata": { + "id": "HkajwI8gteOh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Done\n", + "\n", + "Check your ourput dir and get your slider\n" + ], + "metadata": { + "id": "Hblgb5uwW5SD" + } + } + ] +} diff --git a/ai-toolkit/output/.gitkeep b/ai-toolkit/output/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ai-toolkit/requirements.txt b/ai-toolkit/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..54f3856219dfd58c1952824fcc3c27999a20ad03 --- /dev/null +++ b/ai-toolkit/requirements.txt @@ -0,0 +1,2 @@ +-r requirements_base.txt +scipy==1.12.0 diff --git a/ai-toolkit/requirements_base.txt b/ai-toolkit/requirements_base.txt new file mode 100644 index 0000000000000000000000000000000000000000..726ccb284372c3e645c1aea47eb47c241bfca410 --- /dev/null +++ b/ai-toolkit/requirements_base.txt @@ -0,0 +1,42 @@ +torchao==0.10.0 +safetensors +diffusers==0.32.2 +transformers==5.5.3 +lycoris-lora==1.8.3 +flatten_json +pyyaml +oyaml +tensorboard +kornia +invisible-watermark +einops +accelerate +toml +albumentations==1.4.15 +albucore==0.0.16 +pydantic +omegaconf +k-diffusion +open_clip_torch +timm==1.0.22 +prodigyopt +controlnet_aux==0.0.10 +python-dotenv +bitsandbytes +hf_transfer +lpips +pytorch_fid +optimum-quanto==0.2.4 +sentencepiece +huggingface_hub==1.10.1 +peft==0.18.1 +gradio +python-slugify +opencv-python +pytorch-wavelets==1.3.0 +matplotlib==3.10.1 +setuptools==69.5.1 +av==16.0.1 +torchcodec==0.9.1 +librosa==0.11.0 +mutagen==1.47.0 diff --git a/ai-toolkit/run.py b/ai-toolkit/run.py new file mode 100644 index 0000000000000000000000000000000000000000..5f5f3de4326438d9f2798b62f6d1e4fca4d48e17 --- /dev/null +++ b/ai-toolkit/run.py @@ -0,0 +1,150 @@ +import os +import sys 
+sys.path.insert(0, r'D:\AI_Training\spock_lora') +try: + import spock_compat_shim # noqa: F401 # Spock fork compat shim +except Exception as _e: + print(f"[spock_compat_shim] warn: {_e}") +try: + import torchao_compat # noqa: F401 # backport UIntXWeightOnlyConfig for torchao 0.17+ +except Exception as _e: + print(f"[torchao_compat] warn: {_e}") +from dotenv import load_dotenv +# Load the .env file if it exists +load_dotenv() +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = os.getenv("HF_HUB_ENABLE_HF_TRANSFER", "1") +os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1" +seed = None +if "SEED" in os.environ: + try: + seed = int(os.environ["SEED"]) + except ValueError: + print(f"Invalid SEED value: {os.environ['SEED']}. SEED must be an integer.") + +sys.path.insert(0, os.getcwd()) +# must come before ANY torch or fastai imports +# import toolkit.cuda_malloc + +# turn off diffusers telemetry until I can figure out how to make it opt-in +os.environ['DISABLE_TELEMETRY'] = 'YES' + +# set torch to trace mode +import torch + +# Spock fork: pin CPU thread count to match 16-core Ryzen 9 9950X. +# Without this, PyTorch defaults to the OS's view which can underutilize. 
+torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "16"))) +torch.set_num_interop_threads(int(os.environ.get("TORCH_NUM_INTEROP_THREADS", "16"))) +print(f"[spock_fork] torch threads: {torch.get_num_threads()} intraop, {torch.get_num_interop_threads()} interop") + +# check if we have DEBUG_TOOLKIT in env +if os.environ.get("DEBUG_TOOLKIT", "0") == "1": + torch.autograd.set_detect_anomaly(True) + +if seed is not None: + import random + import numpy as np + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +import argparse +from toolkit.job import get_job +from toolkit.accelerator import get_accelerator +from toolkit.print import print_acc, setup_log_to_file + +accelerator = get_accelerator() + + +def print_end_message(jobs_completed, jobs_failed): + if not accelerator.is_main_process: + return + failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" + completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" + + print_acc("") + print_acc("========================================") + print_acc("Result:") + if len(completed_string) > 0: + print_acc(f" - {completed_string}") + if len(failure_string) > 0: + print_acc(f" - {failure_string}") + print_acc("========================================") + + +def main(): + parser = argparse.ArgumentParser() + + # require at lease one config file + parser.add_argument( + 'config_file_list', + nargs='+', + type=str, + help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially' + ) + + # flag to continue if failed job + parser.add_argument( + '-r', '--recover', + action='store_true', + help='Continue running additional jobs even if a job fails' + ) + + # flag to continue if failed job + parser.add_argument( + '-n', '--name', + type=str, + default=None, + help='Name to 
replace [name] tag in config file, useful for shared config file' + ) + + parser.add_argument( + '-l', '--log', + type=str, + default=None, + help='Log file to write output to' + ) + args = parser.parse_args() + + if args.log is not None: + setup_log_to_file(args.log) + + config_file_list = args.config_file_list + if len(config_file_list) == 0: + raise Exception("You must provide at least one config file") + + jobs_completed = 0 + jobs_failed = 0 + + if accelerator.is_main_process: + print_acc(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}") + + for config_file in config_file_list: + try: + job = get_job(config_file, args.name) + job.run() + job.cleanup() + jobs_completed += 1 + except Exception as e: + print_acc(f"Error running job: {e}") + jobs_failed += 1 + try: + job.process[0].on_error(e) + except Exception as e2: + print_acc(f"Error running on_error: {e2}") + if not args.recover: + print_end_message(jobs_completed, jobs_failed) + raise e + except KeyboardInterrupt as e: + try: + job.process[0].on_error(e) + except Exception as e2: + print_acc(f"Error running on_error: {e2}") + if not args.recover: + print_end_message(jobs_completed, jobs_failed) + raise e + + +if __name__ == '__main__': + main() diff --git a/ai-toolkit/run.py.bak.20260627-144458 b/ai-toolkit/run.py.bak.20260627-144458 new file mode 100644 index 0000000000000000000000000000000000000000..b54929fbae9161f8d6522704bf3cb64108db46fd --- /dev/null +++ b/ai-toolkit/run.py.bak.20260627-144458 @@ -0,0 +1,146 @@ +import os +import sys +sys.path.insert(0, r'D:\AI_Training\spock_lora') +try: + import spock_compat_shim # noqa: F401 # Spock fork compat shim +except Exception as _e: + print(f"[spock_compat_shim] warn: {_e}") +from dotenv import load_dotenv +# Load the .env file if it exists +load_dotenv() +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = os.getenv("HF_HUB_ENABLE_HF_TRANSFER", "1") +os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1" +seed = None +if "SEED" in os.environ: + 
try: + seed = int(os.environ["SEED"]) + except ValueError: + print(f"Invalid SEED value: {os.environ['SEED']}. SEED must be an integer.") + +sys.path.insert(0, os.getcwd()) +# must come before ANY torch or fastai imports +# import toolkit.cuda_malloc + +# turn off diffusers telemetry until I can figure out how to make it opt-in +os.environ['DISABLE_TELEMETRY'] = 'YES' + +# set torch to trace mode +import torch + +# Spock fork: pin CPU thread count to match 16-core Ryzen 9 9950X. +# Without this, PyTorch defaults to the OS's view which can underutilize. +torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "16"))) +torch.set_num_interop_threads(int(os.environ.get("TORCH_NUM_INTEROP_THREADS", "16"))) +print(f"[spock_fork] torch threads: {torch.get_num_threads()} intraop, {torch.get_num_interop_threads()} interop") + +# check if we have DEBUG_TOOLKIT in env +if os.environ.get("DEBUG_TOOLKIT", "0") == "1": + torch.autograd.set_detect_anomaly(True) + +if seed is not None: + import random + import numpy as np + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +import argparse +from toolkit.job import get_job +from toolkit.accelerator import get_accelerator +from toolkit.print import print_acc, setup_log_to_file + +accelerator = get_accelerator() + + +def print_end_message(jobs_completed, jobs_failed): + if not accelerator.is_main_process: + return + failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" + completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" + + print_acc("") + print_acc("========================================") + print_acc("Result:") + if len(completed_string) > 0: + print_acc(f" - {completed_string}") + if len(failure_string) > 0: + print_acc(f" - {failure_string}") + print_acc("========================================") + + +def main(): + parser = argparse.ArgumentParser() + + # require at lease one config 
file + parser.add_argument( + 'config_file_list', + nargs='+', + type=str, + help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially' + ) + + # flag to continue if failed job + parser.add_argument( + '-r', '--recover', + action='store_true', + help='Continue running additional jobs even if a job fails' + ) + + # flag to continue if failed job + parser.add_argument( + '-n', '--name', + type=str, + default=None, + help='Name to replace [name] tag in config file, useful for shared config file' + ) + + parser.add_argument( + '-l', '--log', + type=str, + default=None, + help='Log file to write output to' + ) + args = parser.parse_args() + + if args.log is not None: + setup_log_to_file(args.log) + + config_file_list = args.config_file_list + if len(config_file_list) == 0: + raise Exception("You must provide at least one config file") + + jobs_completed = 0 + jobs_failed = 0 + + if accelerator.is_main_process: + print_acc(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}") + + for config_file in config_file_list: + try: + job = get_job(config_file, args.name) + job.run() + job.cleanup() + jobs_completed += 1 + except Exception as e: + print_acc(f"Error running job: {e}") + jobs_failed += 1 + try: + job.process[0].on_error(e) + except Exception as e2: + print_acc(f"Error running on_error: {e2}") + if not args.recover: + print_end_message(jobs_completed, jobs_failed) + raise e + except KeyboardInterrupt as e: + try: + job.process[0].on_error(e) + except Exception as e2: + print_acc(f"Error running on_error: {e2}") + if not args.recover: + print_end_message(jobs_completed, jobs_failed) + raise e + + +if __name__ == '__main__': + main() diff --git a/ai-toolkit/run_mac.zsh b/ai-toolkit/run_mac.zsh new file mode 100644 index 0000000000000000000000000000000000000000..8f7b49e9e7306f1ced1da8bdf6be65b3fd20fa0e --- /dev/null 
+++ b/ai-toolkit/run_mac.zsh @@ -0,0 +1,166 @@ +#!/usr/bin/env zsh +# Update-and-run script for macOS — portable Python 3.12 + PyTorch +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" + +# ── Banner ───────────────────────────────────────────────────────── +echo "" +echo "\033[36m" +cat << 'BANNER' + _ ___ _____ _ _ _ _ + / \ |_ _| |_ _| ___ ___ | || | __(_)| |_ + / _ \ | | | | / _ \ / _ \| || |/ /| || __| + / ___ \ | | | | | (_) || (_) | || < | || |_ + /_/ \_\|___| |_| \___/ \___/|_||_|\_\|_| \__| +BANNER +echo "\033[0m" +echo "\033[90m macOS Setup & Launcher\033[0m" +echo "" +VENV_DIR="$SCRIPT_DIR/.venv" +PIP="$VENV_DIR/bin/pip" +PYTHON="$VENV_DIR/bin/python3" +PYTHON_VERSION="3.12.8" +RELEASE_TAG="20241219" + +# --- Package versions (update these as needed) --- +NODE_VERSION="23.11.1" +TORCH_VERSION="2.11.0" +TORCHVISION_VERSION="0.26.0" +TORCHAUDIO_VERSION="2.11.0" + +# Detect architecture +ARCH="$(uname -m)" +if [[ "$ARCH" == "arm64" ]]; then + PLATFORM="aarch64-apple-darwin" +elif [[ "$ARCH" == "x86_64" ]]; then + PLATFORM="x86_64-apple-darwin" +else + echo "Error: Unsupported architecture: $ARCH" + exit 1 +fi + +# ── 1. Download standalone Python if needed ───────────────────────── +PYTHON_DIR="$SCRIPT_DIR/.python" +PYTHON_BIN="$PYTHON_DIR/bin/python3" + +if [[ ! -x "$PYTHON_BIN" ]]; then + TARBALL="cpython-${PYTHON_VERSION}+${RELEASE_TAG}-${PLATFORM}-install_only.tar.gz" + URL="https://github.com/indygreg/python-build-standalone/releases/download/${RELEASE_TAG}/${TARBALL}" + + TMPDIR_DL="$(mktemp -d)" + trap 'rm -rf "$TMPDIR_DL"' EXIT + + echo "Downloading standalone Python ${PYTHON_VERSION} (${PLATFORM})..." + curl -fSL --progress-bar -o "$TMPDIR_DL/$TARBALL" "$URL" + + echo "Extracting..." 
+ tar -xzf "$TMPDIR_DL/$TARBALL" -C "$TMPDIR_DL" + + # Move to permanent location (the archive extracts to a "python" folder) + rm -rf "$PYTHON_DIR" + mv "$TMPDIR_DL/python" "$PYTHON_DIR" + + rm -rf "$TMPDIR_DL" + trap - EXIT + + echo "Standalone Python installed to $PYTHON_DIR" +fi + +# ── 2. Create venv if it doesn't exist ────────────────────────────── +if [[ ! -d "$VENV_DIR" ]]; then + echo "Creating virtual environment at $VENV_DIR..." + "$PYTHON_BIN" -m venv "$VENV_DIR" + echo "Virtual environment created." +fi + +# ── 3. Download / update portable Node.js ────────────────────────── +NODE_DIR="$SCRIPT_DIR/.node" +NODE_BIN="$NODE_DIR/bin/node" + +NEED_NODE=false +if [[ ! -x "$NODE_BIN" ]]; then + NEED_NODE=true +elif [[ "$("$NODE_BIN" --version 2>/dev/null)" != "v${NODE_VERSION}" ]]; then + echo "Node.js version mismatch (want v${NODE_VERSION}, have $("$NODE_BIN" --version))." + NEED_NODE=true +fi + +if $NEED_NODE; then + if [[ "$ARCH" == "arm64" ]]; then + NODE_ARCH="arm64" + else + NODE_ARCH="x64" + fi + + NODE_TARBALL="node-v${NODE_VERSION}-darwin-${NODE_ARCH}.tar.gz" + NODE_URL="https://nodejs.org/dist/v${NODE_VERSION}/${NODE_TARBALL}" + + TMPDIR_DL="$(mktemp -d)" + trap 'rm -rf "$TMPDIR_DL"' EXIT + + echo "Downloading Node.js v${NODE_VERSION} (darwin-${NODE_ARCH})..." + curl -fSL --progress-bar -o "$TMPDIR_DL/$NODE_TARBALL" "$NODE_URL" + + echo "Extracting..." + tar -xzf "$TMPDIR_DL/$NODE_TARBALL" -C "$TMPDIR_DL" + + rm -rf "$NODE_DIR" + mv "$TMPDIR_DL/node-v${NODE_VERSION}-darwin-${NODE_ARCH}" "$NODE_DIR" + + rm -rf "$TMPDIR_DL" + trap - EXIT + + echo "Node.js v${NODE_VERSION} installed to $NODE_DIR" +else + echo "Node.js v${NODE_VERSION} is up to date." +fi + +# ── 4. 
Install / update PyTorch packages ──────────────────────────── +# Helper: returns 0 if the package is installed at the exact version +pkg_ok() { + local pkg="$1" want="$2" + local got + got="$("$PIP" show "$pkg" 2>/dev/null | awk '/^Version:/{print $2}')" || true + [[ "$got" == "$want" ]] +} + +PKGS_TO_INSTALL=() + +pkg_ok "torch" "$TORCH_VERSION" || PKGS_TO_INSTALL+=("torch==$TORCH_VERSION") +pkg_ok "torchvision" "$TORCHVISION_VERSION" || PKGS_TO_INSTALL+=("torchvision==$TORCHVISION_VERSION") +pkg_ok "torchaudio" "$TORCHAUDIO_VERSION" || PKGS_TO_INSTALL+=("torchaudio==$TORCHAUDIO_VERSION") + +if (( ${#PKGS_TO_INSTALL[@]} )); then + echo "Installing / updating: ${PKGS_TO_INSTALL[*]}" + "$PIP" install "${PKGS_TO_INSTALL[@]}" +else + echo "PyTorch packages are up to date." +fi + +# ── 5. Install / update requirements.txt ──────────────────────────── +REQUIREMENTS="$SCRIPT_DIR/requirements.txt" +REQ_HASH_FILE="$VENV_DIR/.requirements_hash" + +if [[ -f "$REQUIREMENTS" ]]; then + # Hash all requirements files (follows -r includes) + CURRENT_HASH="$(cat "$SCRIPT_DIR"/requirements*.txt 2>/dev/null | shasum -a 256 | awk '{print $1}')" + STORED_HASH="" + [[ -f "$REQ_HASH_FILE" ]] && STORED_HASH="$(cat "$REQ_HASH_FILE")" + + if [[ "$CURRENT_HASH" != "$STORED_HASH" ]]; then + echo "Installing / updating requirements.txt..." + "$PIP" install -r "$REQUIREMENTS" + echo "$CURRENT_HASH" > "$REQ_HASH_FILE" + else + echo "Requirements are up to date." + fi +fi + +# ── 6. Build and start the UI ─────────────────────────────────────── +export PATH="$NODE_DIR/bin:$VENV_DIR/bin:$PATH" + +echo "" +echo "Starting UI..." 
+cd "$SCRIPT_DIR/ui" +npm run build_and_start diff --git a/ai-toolkit/run_modal.py b/ai-toolkit/run_modal.py new file mode 100644 index 0000000000000000000000000000000000000000..4675c1cb8ec709126317dcba02315177df777f68 --- /dev/null +++ b/ai-toolkit/run_modal.py @@ -0,0 +1,175 @@ +''' + +ostris/ai-toolkit on https://modal.com +Run training with the following command: +modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml + +''' + +import os +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +import sys +import modal +from dotenv import load_dotenv +# Load the .env file if it exists +load_dotenv() + +sys.path.insert(0, "/root/ai-toolkit") +# must come before ANY torch or fastai imports +# import toolkit.cuda_malloc + +# turn off diffusers telemetry until I can figure out how to make it opt-in +os.environ['DISABLE_TELEMETRY'] = 'YES' + +# define the volume for storing model outputs, using "creating volumes lazily": https://modal.com/docs/guide/volumes +# you will find your model, samples and optimizer stored in: https://modal.com/storage/your-username/main/flux-lora-models +model_volume = modal.Volume.from_name("flux-lora-models", create_if_missing=True) + +# modal_output, due to "cannot mount volume on non-empty path" requirement +MOUNT_DIR = "/root/ai-toolkit/modal_output" # modal_output, due to "cannot mount volume on non-empty path" requirement + +# define modal app +image = ( + modal.Image.debian_slim(python_version="3.11") + # install required system and pip packages, more about this modal approach: https://modal.com/docs/examples/dreambooth_app + .apt_install("libgl1", "libglib2.0-0") + .pip_install( + "python-dotenv", + "torch", + "diffusers[torch]", + "transformers", + "ftfy", + "torchvision", + "oyaml", + "opencv-python", + "albumentations", + "safetensors", + "lycoris-lora==1.8.3", + "flatten_json", + "pyyaml", + "tensorboard", + "kornia", + "invisible-watermark", + "einops", + "accelerate", + "toml", + "pydantic", + 
"omegaconf", + "k-diffusion", + "open_clip_torch", + "timm", + "prodigyopt", + "controlnet_aux==0.0.7", + "bitsandbytes", + "hf_transfer", + "lpips", + "pytorch_fid", + "optimum-quanto", + "sentencepiece", + "huggingface_hub", + "peft" + ) +) + +# mount for the entire ai-toolkit directory +# example: "/Users/username/ai-toolkit" is the local directory, "/root/ai-toolkit" is the remote directory +code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit") + +# create the Modal app with the necessary mounts and volumes +app = modal.App(name="flux-lora-training", image=image, mounts=[code_mount], volumes={MOUNT_DIR: model_volume}) + +# Check if we have DEBUG_TOOLKIT in env +if os.environ.get("DEBUG_TOOLKIT", "0") == "1": + # Set torch to trace mode + import torch + torch.autograd.set_detect_anomaly(True) + +import argparse +from toolkit.job import get_job + +def print_end_message(jobs_completed, jobs_failed): + failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" + completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" + + print("") + print("========================================") + print("Result:") + if len(completed_string) > 0: + print(f" - {completed_string}") + if len(failure_string) > 0: + print(f" - {failure_string}") + print("========================================") + + +@app.function( + # request a GPU with at least 24GB VRAM + # more about modal GPU's: https://modal.com/docs/guide/gpu + gpu="A100", # gpu="H100" + # more about modal timeouts: https://modal.com/docs/guide/timeouts + timeout=7200 # 2 hours, increase or decrease if needed +) +def main(config_file_list_str: str, recover: bool = False, name: str = None): + # convert the config file list from a string to a list + config_file_list = config_file_list_str.split(",") + + jobs_completed = 0 + jobs_failed = 0 + + print(f"Running {len(config_file_list)} job{'' if 
len(config_file_list) == 1 else 's'}") + + for config_file in config_file_list: + try: + job = get_job(config_file, name) + + job.config['process'][0]['training_folder'] = MOUNT_DIR + os.makedirs(MOUNT_DIR, exist_ok=True) + print(f"Training outputs will be saved to: {MOUNT_DIR}") + + # run the job + job.run() + + # commit the volume after training + model_volume.commit() + + job.cleanup() + jobs_completed += 1 + except Exception as e: + print(f"Error running job: {e}") + jobs_failed += 1 + if not recover: + print_end_message(jobs_completed, jobs_failed) + raise e + + print_end_message(jobs_completed, jobs_failed) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # require at least one config file + parser.add_argument( + 'config_file_list', + nargs='+', + type=str, + help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially' + ) + + # flag to continue if a job fails + parser.add_argument( + '-r', '--recover', + action='store_true', + help='Continue running additional jobs even if a job fails' + ) + + # optional name replacement for config file + parser.add_argument( + '-n', '--name', + type=str, + default=None, + help='Name to replace [name] tag in config file, useful for shared config file' + ) + args = parser.parse_args() + + # convert list of config files to a comma-separated string for Modal compatibility + config_file_list_str = ",".join(args.config_file_list) + + main.call(config_file_list_str=config_file_list_str, recover=args.recover, name=args.name) diff --git a/ai-toolkit/scripts/add_mask_dataset.py b/ai-toolkit/scripts/add_mask_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8dda4863b41c6acddf4b9cf51c0d7a7d88d32f36 --- /dev/null +++ b/ai-toolkit/scripts/add_mask_dataset.py @@ -0,0 +1,256 @@ +import os +import sys +import queue +import random +import argparse +import threading 
+import traceback + +import torch +from tqdm import tqdm + +# allow importing from project root +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from toolkit.control_generator import ControlGenerator, img_ext_list + + +def control_exists(img_path, control_type): + # mirrors the lookup in ControlGenerator.get_control_path so we can skip + # images another instance has already finished + controls_folder = os.path.join(os.path.dirname(img_path), "_controls") + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + file_name_no_ext_control = f"{file_name_no_ext}.{control_type}" + for ext in img_ext_list: + if os.path.exists(os.path.join(controls_folder, file_name_no_ext_control + ext)): + return True + return False + + +# sentinel pushed through the queues to tell workers to stop +_DONE = object() + + +def run_pipeline(control_gen, img_list, control_type, regen, n_load, n_save): + # Three-stage pipeline so the GPU never waits on disk/CPU work: + # loaders (N threads) -> read + exif + resize + preprocess to a CPU tensor + # gpu worker (1 thread) -> model forward only (kept single so VRAM is bounded) + # savers (M threads) -> postprocess (resize/alpha) + encode + write + # The heavy CPU work (resize/normalize on input, resize/convert on output) is + # pushed onto the loader/saver threads so the GPU thread does almost nothing + # but the forward pass. Bounded queues apply backpressure so we don't load the + # whole dataset into RAM. + infer_q = queue.Queue(maxsize=n_load * 2) + save_q = queue.Queue(maxsize=n_save * 2) + path_q = queue.Queue() + for img_path in img_list: + path_q.put(img_path) + + # miniters=1 disables tqdm's dynamic-miniters heuristic (which otherwise + # raises the redraw threshold after a fast burst and makes the bar look + # frozen); mininterval keeps redraws time-based and cheap. 
+ pbar = tqdm(total=len(img_list), desc=f"Generating {control_type}", + miniters=1, mininterval=0.5) + pbar_lock = threading.Lock() + # set on completion OR on Ctrl-C; every blocking call below uses a timeout and + # re-checks this so the worker threads can actually be shut down. + stop_event = threading.Event() + + def put(q, item): + # interruptible put: blocks until there's room, but wakes periodically so + # a stop request (or KeyboardInterrupt on the main thread) is honored. + while not stop_event.is_set(): + try: + q.put(item, timeout=0.2) + return + except queue.Full: + continue + + def loader(): + while not stop_event.is_set(): + try: + img_path = path_q.get_nowait() + except queue.Empty: + break + try: + if not regen and control_exists(img_path, control_type): + # another instance (or a previous run) already did it + with pbar_lock: + pbar.update(1) + continue + image = control_gen.load_image(img_path) + payload = control_gen.preprocess(image, control_type) + put(infer_q, (img_path, image, payload)) + except Exception: + traceback.print_exc() + with pbar_lock: + pbar.update(1) + + def saver(): + while not stop_event.is_set(): + try: + item = save_q.get(timeout=0.2) + except queue.Empty: + continue + if item is _DONE: + break + img_path, image, result = item + try: + out_image = control_gen.postprocess(result, image, control_type) + save_path = control_gen.control_save_path(img_path, control_type) + control_gen.save_control(out_image, save_path) + except Exception: + traceback.print_exc() + finally: + with pbar_lock: + pbar.update(1) + + loaders = [threading.Thread(target=loader, daemon=True) for _ in range(n_load)] + savers = [threading.Thread(target=saver, daemon=True) for _ in range(n_save)] + for t in loaders + savers: + t.start() + + # GPU stage runs on the main thread: pull preprocessed tensors, run the + # forward pass, hand the raw result off to the savers. We stop once every + # loader has exited and nothing is left queued for inference. 
+ interrupted = False + try: + while not stop_event.is_set(): + if not any(t.is_alive() for t in loaders) and infer_q.empty(): + break + try: + img_path, image, payload = infer_q.get(timeout=0.1) + except queue.Empty: + continue + try: + result = control_gen.run_inference(payload, control_type) + put(save_q, (img_path, image, result)) + except Exception: + traceback.print_exc() + with pbar_lock: + pbar.update(1) + except KeyboardInterrupt: + interrupted = True + print("\nInterrupted, shutting down...") + + if interrupted: + # abort: tell every worker to stop; pending items are dropped + stop_event.set() + else: + # normal finish: let savers drain whatever is still queued, then stop + for _ in savers: + save_q.put(_DONE) + + # join with a timeout so a stuck worker can never wedge shutdown; threads are + # daemons, so anything still alive is torn down when we return. + for t in savers: + t.join(timeout=5) + pbar.close() + if interrupted: + raise KeyboardInterrupt + + +def main(): + parser = argparse.ArgumentParser( + description="Generate masks for a dataset using the ControlGenerator" + ) + parser.add_argument("img_dir", type=str, help="Path to image directory") + parser.add_argument( + "--control", + type=str, + default="mask", + choices=["mask", "inpaint", "depth", "pose", "line", "sapiens2_mask"], + help="Control type to generate (default: mask)", + ) + parser.add_argument( + "--device", type=str, default="cuda", help="Device to run on (default: cuda)" + ) + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + parser.add_argument( + "--regen", + action="store_true", + help="Regenerate controls even if they already exist", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="Shuffle image order so multiple instances on the same dataset " + "don't chase the same images", + ) + parser.add_argument( + "--load-workers", + type=int, + default=16, + help="Number of threads for loading/resizing images (default: 4)", + ) + 
parser.add_argument( + "--save-workers", + type=int, + default=16, + help="Number of threads for saving controls (default: 4)", + ) + args = parser.parse_args() + + img_dir = args.img_dir + if not os.path.isdir(img_dir): + print(f"Error: {img_dir} is not a directory") + sys.exit(1) + + # find images, skipping existing _controls folders and hidden files + img_list = [] + for root, dirs, files in os.walk(img_dir): + if "_controls" in root: + continue + for file in files: + if file.startswith("."): + continue + if file.lower().endswith(tuple(img_ext_list)): + img_list.append(os.path.join(root, file)) + + if len(img_list) == 0: + print(f"Error: no images found in {img_dir}") + sys.exit(1) + + # filter out images that already have controls up front so the progress bar + # reflects only real work (otherwise it races through thousands of skips and + # the rate/ETA are meaningless). The loader still re-checks just before + # processing to handle the multi-instance race. + if not args.regen: + total = len(img_list) + img_list = [p for p in img_list if not control_exists(p, args.control)] + skipped = total - len(img_list) + if skipped: + print(f"Skipping {skipped} images that already have '{args.control}' controls") + if len(img_list) == 0: + print("All images already have controls. 
Nothing to do.") + return + + if args.shuffle: + random.shuffle(img_list) + + control_gen = ControlGenerator(torch.device(args.device)) + control_gen.debug = args.debug + control_gen.regen = args.regen + + interrupted = False + try: + run_pipeline( + control_gen, + img_list, + args.control, + args.regen, + max(1, args.load_workers), + max(1, args.save_workers), + ) + except KeyboardInterrupt: + interrupted = True + finally: + control_gen.cleanup() + + if interrupted: + sys.exit(130) + print("Done") + + +if __name__ == "__main__": + main() diff --git a/ai-toolkit/scripts/calculate_timestep_weighing_flex.py b/ai-toolkit/scripts/calculate_timestep_weighing_flex.py new file mode 100644 index 0000000000000000000000000000000000000000..05a217667f760758e6bbc9afe97fe0111b649386 --- /dev/null +++ b/ai-toolkit/scripts/calculate_timestep_weighing_flex.py @@ -0,0 +1,228 @@ +import gc +import os, sys +from tqdm import tqdm +import numpy as np +import json +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# set visible devices to 0 +# os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# protect from formatting +if True: + import torch + from optimum.quanto import freeze, qfloat8, QTensor, qint4 + from diffusers import FluxTransformer2DModel, FluxPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler + from toolkit.util.quantize import quantize, get_qtype + from transformers import T5EncoderModel, T5TokenizerFast, CLIPTextModel, CLIPTokenizer + from torchvision import transforms + +qtype = "qfloat8" +dtype = torch.bfloat16 +# base_model_path = "black-forest-labs/FLUX.1-dev" +base_model_path = "ostris/Flex.1-alpha" +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print("Loading Transformer...") +prompt = "Photo of a man and a woman in a park, sunny day" + +output_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "output") +output_path = os.path.join(output_root, "flex_timestep_weights.json") 
+img_output_path = os.path.join(output_root, "flex_timestep_weights.png") + +quantization_type = get_qtype(qtype) + +def flush(): + torch.cuda.empty_cache() + gc.collect() + +pil_to_tensor = transforms.ToTensor() + +with torch.no_grad(): + transformer = FluxTransformer2DModel.from_pretrained( + base_model_path, + subfolder='transformer', + torch_dtype=dtype + ) + + transformer.to(device, dtype=dtype) + + print("Quantizing Transformer...") + quantize(transformer, weights=quantization_type) + freeze(transformer) + flush() + + print("Loading Scheduler...") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + + print("Loading Autoencoder...") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + + vae.to(device, dtype=dtype) + + flush() + print("Loading Text Encoder...") + tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) + text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) + text_encoder_2.to(device, dtype=dtype) + + print("Quantizing Text Encoder...") + quantize(text_encoder_2, weights=get_qtype(qtype)) + freeze(text_encoder_2) + flush() + + print("Loading CLIP") + text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder.to(device, dtype=dtype) + + print("Making pipe") + + pipe: FluxPipeline = FluxPipeline( + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + pipe.to(device, dtype=dtype) + + print("Encoding prompt...") + + prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt( + prompt, + prompt_2=prompt, 
+ device=device + ) + + + generator = torch.manual_seed(42) + + height = 1024 + width = 1024 + + print("Generating image...") + + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + latents = callback_kwargs["latents"] + if latents.dtype != dtype: + latents = latents.to(dtype) + return {"latents": latents} + img = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + height=height, + width=height, + num_inference_steps=30, + guidance_scale=3.5, + generator=generator, + callback_on_step_end=callback_on_step_end, + ).images[0] + + img.save(img_output_path) + print(f"Image saved to {img_output_path}") + + print("Encoding image...") + # img is a PIL image. convert it to a -1 to 1 tensor + img = pil_to_tensor(img) + img = img.unsqueeze(0) # add batch dimension + img = img * 2 - 1 # convert to -1 to 1 range + img = img.to(device, dtype=dtype) + latents = vae.encode(img).latent_dist.sample() + + shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0 + latents = vae.config['scaling_factor'] * (latents - shift) + + num_channels_latents = pipe.transformer.config.in_channels // 4 + + l_height = 2 * (int(height) // (pipe.vae_scale_factor * 2)) + l_width = 2 * (int(width) // (pipe.vae_scale_factor * 2)) + packed_latents = pipe._pack_latents(latents, 1, num_channels_latents, l_height, l_width) + + packed_latents, latent_image_ids = pipe.prepare_latents( + 1, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + packed_latents, + ) + + print("Calculating timestep weights...") + + torch.manual_seed(8675309) + noise = torch.randn_like(packed_latents, device=device, dtype=dtype) + + # Create linear timesteps from 1000 to 0 + num_train_timesteps = 1000 + timesteps_torch = torch.linspace(1000, 1, num_train_timesteps, device='cpu') + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = 
torch.from_numpy(timesteps).to(dtype=torch.float32) + + timestep_weights = torch.zeros(num_train_timesteps, dtype=torch.float32, device=device) + + guidance = torch.full([1], 1.0, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + pbar = tqdm(range(num_train_timesteps), desc="loss: 0.000000 scaler: 0.0000") + for i in pbar: + timestep = timesteps[i:i+1].to(device) + t_01 = (timestep / 1000).to(device) + t_01 = t_01.reshape(-1, 1, 1) + noisy_latents = (1.0 - t_01) * packed_latents + t_01 * noise + + noise_pred = pipe.transformer( + hidden_states=noisy_latents, # torch.Size([1, 4096, 64]) + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + + target = noise - packed_latents + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float()) + loss = loss + + # determine scaler to multiply loss by to make it 1 + scaler = 1.0 / (loss + 1e-6) + + timestep_weights[i] = scaler + pbar.set_description(f"loss: {loss.item():.6f} scaler: {scaler.item():.4f}") + + print("normalizing timestep weights...") + # normalize the timestep weights so they are a mean of 1.0 + timestep_weights = timestep_weights / timestep_weights.mean() + timestep_weights = timestep_weights.cpu().numpy().tolist() + + print("Saving timestep weights...") + + with open(output_path, 'w') as f: + json.dump(timestep_weights, f) + + +print(f"Timestep weights saved to {output_path}") +print("Done!") +flush() + + + + + + + + + + + + \ No newline at end of file diff --git a/ai-toolkit/scripts/caption_audio_dataset.py b/ai-toolkit/scripts/caption_audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8cddf7131921e5137cb2bbeed34ec7f4ef5d0de9 --- /dev/null +++ b/ai-toolkit/scripts/caption_audio_dataset.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +""" +Caption audio files for ACE-Step v1.5 
training. + +Produces .txt files containing all training metadata: + - caption (from acestep-captioner) + - lyrics (from acestep-transcriber) + - bpm, keyscale, timesignature (from librosa) + - duration, language + +Requirements: + pip install torch torchaudio transformers librosa numpy + +Usage: + python caption_dir.py input_dir/ + python caption_dir.py input_dir/ --low_vram --skip_existing +""" + +import argparse +import gc +import os +import glob +import logging +import warnings + +import librosa +import numpy as np +import torch +import torchaudio +from tqdm import tqdm +from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor + +warnings.filterwarnings("ignore") +logging.disable(logging.WARNING) + +TARGET_SAMPLE_RATE = 16000 +CAPTIONER_ID = "ACE-Step/acestep-captioner" +TRANSCRIBER_ID = "ACE-Step/acestep-transcriber" + +# Key profiles for Krumhansl-Schmuckler key detection +MAJOR_PROFILE = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88]) +MINOR_PROFILE = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17]) +KEY_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] + + +def get_audio_files(input_dir): + extensions = ["*.wav", "*.mp3", "*.flac", "*.ogg", "*.WAV", "*.MP3", "*.FLAC"] + files = [] + for ext in extensions: + files.extend(glob.glob(os.path.join(input_dir, ext))) + return sorted(set(files)) + + +def load_audio_mono_16k(audio_path): + waveform, sr = torchaudio.load(audio_path) + if waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + if sr != TARGET_SAMPLE_RATE: + waveform = torchaudio.functional.resample(waveform, sr, TARGET_SAMPLE_RATE) + return waveform.squeeze(0).numpy(), TARGET_SAMPLE_RATE + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Audio analysis (BPM, key, time signature) via librosa +# ═══════════════════════════════════════════════════════════════════════════════ + +def 
def analyze_audio(audio_path):
    """Extract BPM, key, time signature and duration from an audio file using librosa.

    Args:
        audio_path: Path of the audio file; loaded as mono at 22050 Hz.

    Returns:
        dict with keys ``"bpm"`` (int), ``"keyscale"`` (e.g. ``"C major"``),
        ``"timesignature"`` (``"3"`` or ``"4"``) and ``"duration"`` (whole
        seconds, rounded).

    Raises:
        Whatever librosa raises for unreadable input. ``IndexError`` or
        ``ValueError`` if no usable tempo estimate is produced.
    """
    y, sr = librosa.load(audio_path, sr=22050, mono=True)
    duration = librosa.get_duration(y=y, sr=sr)

    # BPM. The tempo may come back as a plain float, a 0-d array or a
    # 1-element array; flatten first so every shape is handled the same way
    # (indexing a 0-d array with [0] raises IndexError).
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
    tempo_value = float(np.atleast_1d(tempo).ravel()[0])
    bpm = int(round(tempo_value))

    # Key detection: correlate the average chroma vector with the major/minor
    # key profiles rotated to each of the 12 possible tonics.
    chroma = librosa.feature.chroma_cqt(y=y, sr=sr)
    chroma_avg = chroma.mean(axis=1)

    def profile_correlations(profile):
        corrs = np.array([np.corrcoef(np.roll(profile, i), chroma_avg)[0, 1] for i in range(12)])
        # A flat chroma vector has zero variance and yields NaN correlations;
        # map them to the minimum so comparisons below stay well defined.
        return np.nan_to_num(corrs, nan=-1.0)

    major_corrs = profile_correlations(MAJOR_PROFILE)
    minor_corrs = profile_correlations(MINOR_PROFILE)

    best_major_idx = major_corrs.argmax()
    best_minor_idx = minor_corrs.argmax()
    if major_corrs[best_major_idx] >= minor_corrs[best_minor_idx]:
        keyscale = f"{KEY_NAMES[best_major_idx]} major"
    else:
        keyscale = f"{KEY_NAMES[best_minor_idx]} minor"

    # Time signature estimation from the periodicity of beat strengths.
    onset_env = librosa.onset.onset_strength(y=y, sr=sr)
    _, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr)
    timesig = "4"  # default whenever there is too little evidence
    if len(beats) >= 8:
        beat_strengths = onset_env[beats]
        centered = beat_strengths - beat_strengths.mean()
        acf = np.correlate(centered, centered, mode='full')
        acf = acf[len(acf) // 2:]  # keep non-negative lags only
        # Prefer 3/4 only when the lag-3 peak clearly beats the lag-4 peak.
        if len(acf) > 6 and acf[3] > acf[4] * 1.2:
            timesig = "3"

    return {
        "bpm": bpm,
        "keyscale": keyscale,
        "timesignature": timesig,
        "duration": int(round(duration)),
    }
def offload_to_cpu(model):
    """Move ``model`` to the CPU (when it is not None) and release cached GPU memory."""
    if model is not None:
        model.to("cpu")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def load_qwen_model(model_id, device="cuda", dtype=torch.bfloat16):
    """Load a Qwen2.5-Omni model and its processor.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.
        device: Value forwarded as ``device_map``.
        dtype: Value forwarded as ``torch_dtype``.

    Returns:
        Tuple ``(model, processor)``; the model has ``disable_talker()`` applied.
    """
    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=dtype, device_map=device,
    )
    model.disable_talker()
    processor = Qwen2_5OmniProcessor.from_pretrained(model_id)
    return model, processor


def run_qwen_audio(model, processor, audio_data, sr, prompt_text):
    """Run a Qwen2.5-Omni model on one audio clip with a text prompt.

    Args:
        model: Model returned by ``load_qwen_model``.
        processor: Matching processor.
        audio_data: Audio samples handed to the processor.
        sr: Sampling rate of ``audio_data``.
        prompt_text: Instruction placed after the audio in the user turn.

    Returns:
        The decoded text following the last ``"assistant\\n"`` marker
        (or the whole decoded text when the marker is absent), stripped.
    """
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": "<|audio_bos|><|AUDIO|><|audio_eos|>"},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(
        text=text, audio=[audio_data], images=None, videos=None,
        return_tensors="pt", padding=True, sampling_rate=sr,
    )
    inputs = inputs.to(model.device).to(model.dtype)
    text_ids = model.generate(**inputs, return_audio=False)
    output = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    result = output[0]
    # The decoded text still contains the prompt; keep only the assistant reply.
    marker = "assistant\n"
    if marker in result:
        result = result[result.rfind(marker) + len(marker):]
    return result.strip()


def format_output(caption, lyrics, analysis, language="en"):
    """Format all metadata into the tagged text layout read by ``parse_caption_file``.

    Every field is wrapped in ``<NAME>...</NAME>`` using the tag names the
    parser looks for (CAPTION, LYRICS, BPM, KEYSCALE, TIMESIGNATURE, DURATION,
    LANGUAGE). Without the tags the written files cannot be parsed back.

    Args:
        caption: Free-text description of the music.
        lyrics: Transcribed lyrics.
        analysis: dict with ``bpm``, ``keyscale``, ``timesignature`` and ``duration``.
        language: Language code stored in the LANGUAGE tag.

    Returns:
        The tagged text as a single string.
    """
    return (
        f"<CAPTION>\n{caption}\n</CAPTION>\n"
        f"<LYRICS>\n{lyrics}\n</LYRICS>\n"
        f"<BPM>{analysis['bpm']}</BPM>\n"
        f"<KEYSCALE>{analysis['keyscale']}</KEYSCALE>\n"
        f"<TIMESIGNATURE>{analysis['timesignature']}</TIMESIGNATURE>\n"
        f"<DURATION>{analysis['duration']}</DURATION>\n"
        f"<LANGUAGE>{language}</LANGUAGE>"
    )
def parse_caption_file(path):
    """Parse a tagged caption file back into a dict.

    Each field is read from a ``<NAME>...</NAME>`` section. The pattern is
    anchored on the closing tag; without it the lazy ``(.*?)`` group matches
    the empty string and every field comes back empty.

    Args:
        path: Path of the UTF-8 encoded caption file.

    Returns:
        dict with the keys ``caption``, ``lyrics``, ``bpm``, ``keyscale``,
        ``timesignature``, ``duration`` and ``language``. Values are stripped
        strings; a missing tag yields ``""``.
    """
    import re

    # Context manager so the file handle is closed deterministically.
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    def tag(name):
        m = re.search(rf"<{name}>(.*?)</{name}>", text, re.DOTALL)
        return m.group(1).strip() if m else ""

    return {
        "caption": tag("CAPTION"),
        "lyrics": tag("LYRICS"),
        "bpm": tag("BPM"),
        "keyscale": tag("KEYSCALE"),
        "timesignature": tag("TIMESIGNATURE"),
        "duration": tag("DURATION"),
        "language": tag("LANGUAGE"),
    }
"timesignature": "4", + "duration": 30} + + # Filter to only files that need processing + files_to_process = [f for f in audio_files if f in analyses] + if not files_to_process: + print("All files already captioned (use without --skip_existing to overwrite)") + return + + # ── Stage 2: Captioning ────────────────────────────────────────────── + print(f"\n[Stage 2/3] Captioning {len(files_to_process)} files...") + print(" Loading captioner model...") + captioner, cap_processor = load_qwen_model(CAPTIONER_ID) + + captions = {} + for audio_path in tqdm(files_to_process, desc="Captioning"): + try: + audio_data, sr = load_audio_mono_16k(audio_path) + caption = run_qwen_audio( + captioner, cap_processor, audio_data, sr, + "*Task* Describe this music in detail. Include genre, mood, instrumentation, tempo feel, and vocal style if present." + ) + captions[audio_path] = caption + except Exception as e: + print(f"\n Error captioning {os.path.basename(audio_path)}: {e}") + captions[audio_path] = "" + + if args.low_vram: + print(" Offloading captioner to CPU...") + offload_to_cpu(captioner) + del captioner, cap_processor + + # ── Stage 3: Lyrics transcription ──────────────────────────────────── + print(f"\n[Stage 3/3] Transcribing lyrics for {len(files_to_process)} files...") + print(" Loading transcriber model...") + transcriber, trans_processor = load_qwen_model(TRANSCRIBER_ID) + + lyrics_map = {} + for audio_path in tqdm(files_to_process, desc="Transcribing"): + try: + audio_data, sr = load_audio_mono_16k(audio_path) + lyrics = run_qwen_audio( + transcriber, trans_processor, audio_data, sr, + "*Task* Transcribe this audio in detail" + ) + lyrics_map[audio_path] = lyrics + except Exception as e: + print(f"\n Error transcribing {os.path.basename(audio_path)}: {e}") + lyrics_map[audio_path] = "[Instrumental]" + + if args.low_vram: + print(" Offloading transcriber to CPU...") + offload_to_cpu(transcriber) + del transcriber, trans_processor + + # ── Write output files 
─────────────────────────────────────────────── + print("\nWriting output files...") + for audio_path in files_to_process: + base_name = os.path.splitext(audio_path)[0] + output_path = base_name + ".txt" + + caption = captions.get(audio_path, "") + lyrics = lyrics_map.get(audio_path, "[Instrumental]") + analysis = analyses[audio_path] + + output = format_output(caption, lyrics, analysis, args.language) + with open(output_path, "w", encoding="utf-8") as f: + f.write(output) + + print(f"Done! Processed {len(files_to_process)} files.") + + +if __name__ == "__main__": + main() diff --git a/ai-toolkit/scripts/convert_cog.py b/ai-toolkit/scripts/convert_cog.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4f6e73c1d3e444583319b37557ad36ec988ccf --- /dev/null +++ b/ai-toolkit/scripts/convert_cog.py @@ -0,0 +1,128 @@ +import json +from collections import OrderedDict +import os +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +device = torch.device('cpu') + +# [diffusers] -> kohya +embedding_mapping = { + 'text_encoders_0': 'clip_l', + 'text_encoders_1': 'clip_g' +} + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +KEYMAP_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps') +sdxl_keymap_path = os.path.join(KEYMAP_ROOT, 'stable_diffusion_locon_sdxl.json') + +# load keymap +with open(sdxl_keymap_path, 'r') as f: + ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap'] + +# invert the item / key pairs +diffusers_ldm_keymap = {v: k for k, v in ldm_diffusers_keymap.items()} + + +def get_ldm_key(diffuser_key): + diffuser_key = f"lora_unet_{diffuser_key.replace('.', '_')}" + diffuser_key = diffuser_key.replace('_lora_down_weight', '.lora_down.weight') + diffuser_key = diffuser_key.replace('_lora_up_weight', '.lora_up.weight') + diffuser_key = diffuser_key.replace('_alpha', '.alpha') + diffuser_key = diffuser_key.replace('_processor_to_', '_to_') + diffuser_key = 
def convert_cog(lora_path, embedding_path):
    """Convert a diffusers-style LoRA and embedding file pair into kohya-style state dicts.

    Args:
        lora_path: Path of the LoRA safetensors file.
        embedding_path: Path of the embedding safetensors file; its keys are
            renamed through ``embedding_mapping``.

    Returns:
        Tuple ``(lora_state_dict, embedding_state_dict)`` of OrderedDicts. For
        every ``.lora_down.weight`` entry an ``.alpha`` tensor holding the LoRA
        rank is added.

    Raises:
        KeyError: A key is missing from the keymap or the embedding mapping.
        ValueError: The probed layers disagree on the rank, or the rank could
            not be determined although an alpha entry is required.
    """
    embedding_state_dict = OrderedDict()
    lora_state_dict = OrderedDict()

    with safe_open(embedding_path, framework="pt", device='cpu') as f:
        for key in list(f.keys()):
            embedding_state_dict[embedding_mapping[key]] = f.get_tensor(key)

    with safe_open(lora_path, framework="pt", device='cpu') as f:
        keys = list(f.keys())
        lora_rank = None

        # Determine the lora dim from the first 3 linear (2-D) layers. The
        # counter lives outside the loop: resetting it on every iteration
        # meant the early exit never triggered and every tensor was loaded.
        num_checked = 0
        for key in keys:
            tensor = f.get_tensor(key)
            if len(tensor.shape) != 2:
                continue
            this_dim = min(tensor.shape)
            if lora_rank is None:
                lora_rank = this_dim
            elif lora_rank != this_dim:
                raise ValueError(f"lora rank is not consistent, got {tensor.shape}")
            num_checked += 1
            if num_checked >= 3:
                break

        for key in keys:
            new_key = get_ldm_key(key)
            tensor = f.get_tensor(key)
            if new_key.endswith('.lora_down.weight'):
                if lora_rank is None:
                    raise ValueError("could not determine lora rank: no 2-D weight found")
                alpha_key = new_key.replace('.lora_down.weight', '.alpha')
                # diffusers does not store alpha; it uses an alpha multiplier of 1,
                # which corresponds to an alpha tensor equal to the lora rank.
                lora_state_dict[alpha_key] = torch.ones(1).to(tensor.device, tensor.dtype) * lora_rank

            lora_state_dict[new_key] = tensor

    return lora_state_dict, embedding_state_dict
parser.add_argument( + 'lora_path', + type=str, + help='Path to lora file' + ) + parser.add_argument( + 'embedding_path', + type=str, + help='Path to embedding file' + ) + + parser.add_argument( + '--lora_output', + type=str, + default="lora_output", + ) + + parser.add_argument( + '--embedding_output', + type=str, + default="embedding_output", + ) + + args = parser.parse_args() + + lora_state_dict, embedding_state_dict = convert_cog(args.lora_path, args.embedding_path) + + # save them + save_file(lora_state_dict, args.lora_output) + save_file(embedding_state_dict, args.embedding_output) + print(f"Saved lora to {args.lora_output}") + print(f"Saved embedding to {args.embedding_output}") diff --git a/ai-toolkit/scripts/convert_diffusers_to_comfy.py b/ai-toolkit/scripts/convert_diffusers_to_comfy.py new file mode 100644 index 0000000000000000000000000000000000000000..28aa7ee8484d9bd97916a3f9450527e0c53fb825 --- /dev/null +++ b/ai-toolkit/scripts/convert_diffusers_to_comfy.py @@ -0,0 +1,426 @@ +####################################################### +# Convert Diffusers Flux/Flex to all in one ComfyUI safetensors file +# The VAE, T5 and clip will all be in the safetensors file +# T5 will always be 8bit with the all in one file +# You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag +# +# Download a reference model from Huggingface +# https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors +# +# Call like this for 8-bit transformer weights: +# python convert_flux_diffusers_to_orig.py /path/to/diffusers/checkpoint /path/to/flux1-dev-fp8.safetensors /output/path/my_finetune.safetensors --do_8_bit +# +# Call like this for bf16 transformer weights: +# python convert_flux_diffusers_to_orig.py /path/to/diffusers/checkpoint /path/to/flux1-dev-fp8.safetensors /output/path/my_finetune.safetensors +# +####################################################### + + +import argparse +from datetime import date +import json +import os +from 
pathlib import Path +import safetensors +import safetensors.torch +import torch +import tqdm +from collections import OrderedDict + + +parser = argparse.ArgumentParser() + +parser.add_argument("diffusers_path", type=str, + help="Path to the original Flux diffusers folder.") +parser.add_argument("quantized_state_dict_path", type=str, + help="Path to the ComfyUI all in one template file.") +parser.add_argument("flux_path", type=str, + help="Output path for the Flux safetensors file.") +parser.add_argument("--do_8_bit", action="store_true", + help="Use 8-bit weights instead of bf16.") +args = parser.parse_args() + +flux_path = Path(args.flux_path) +diffusers_path = Path(args.diffusers_path, "transformer") +quantized_state_dict_path = Path(args.quantized_state_dict_path) + +do_8_bit = args.do_8_bit + +if not os.path.exists(flux_path.parent): + os.makedirs(flux_path.parent) + +if not diffusers_path.exists(): + print(f"Error: Missing transformer folder: {diffusers_path}") + exit() + +original_json_path = Path.joinpath( + diffusers_path, "diffusion_pytorch_model.safetensors.index.json") +if not original_json_path.exists(): + print(f"Error: Missing transformer index json: {original_json_path}") + exit() + +if not os.path.exists(quantized_state_dict_path): + print( + f"Error: Missing quantized state dict file: {args.quantized_state_dict_path}") + exit() + +with open(original_json_path, "r", encoding="utf-8") as f: + original_json = json.load(f) + +diffusers_map = { + "time_in.in_layer.weight": [ + "time_text_embed.timestep_embedder.linear_1.weight", + ], + "time_in.in_layer.bias": [ + "time_text_embed.timestep_embedder.linear_1.bias", + ], + "time_in.out_layer.weight": [ + "time_text_embed.timestep_embedder.linear_2.weight", + ], + "time_in.out_layer.bias": [ + "time_text_embed.timestep_embedder.linear_2.bias", + ], + "vector_in.in_layer.weight": [ + "time_text_embed.text_embedder.linear_1.weight", + ], + "vector_in.in_layer.bias": [ + 
"time_text_embed.text_embedder.linear_1.bias", + ], + "vector_in.out_layer.weight": [ + "time_text_embed.text_embedder.linear_2.weight", + ], + "vector_in.out_layer.bias": [ + "time_text_embed.text_embedder.linear_2.bias", + ], + "guidance_in.in_layer.weight": [ + "time_text_embed.guidance_embedder.linear_1.weight", + ], + "guidance_in.in_layer.bias": [ + "time_text_embed.guidance_embedder.linear_1.bias", + ], + "guidance_in.out_layer.weight": [ + "time_text_embed.guidance_embedder.linear_2.weight", + ], + "guidance_in.out_layer.bias": [ + "time_text_embed.guidance_embedder.linear_2.bias", + ], + "txt_in.weight": [ + "context_embedder.weight", + ], + "txt_in.bias": [ + "context_embedder.bias", + ], + "img_in.weight": [ + "x_embedder.weight", + ], + "img_in.bias": [ + "x_embedder.bias", + ], + "double_blocks.().img_mod.lin.weight": [ + "norm1.linear.weight", + ], + "double_blocks.().img_mod.lin.bias": [ + "norm1.linear.bias", + ], + "double_blocks.().txt_mod.lin.weight": [ + "norm1_context.linear.weight", + ], + "double_blocks.().txt_mod.lin.bias": [ + "norm1_context.linear.bias", + ], + "double_blocks.().img_attn.qkv.weight": [ + "attn.to_q.weight", + "attn.to_k.weight", + "attn.to_v.weight", + ], + "double_blocks.().img_attn.qkv.bias": [ + "attn.to_q.bias", + "attn.to_k.bias", + "attn.to_v.bias", + ], + "double_blocks.().txt_attn.qkv.weight": [ + "attn.add_q_proj.weight", + "attn.add_k_proj.weight", + "attn.add_v_proj.weight", + ], + "double_blocks.().txt_attn.qkv.bias": [ + "attn.add_q_proj.bias", + "attn.add_k_proj.bias", + "attn.add_v_proj.bias", + ], + "double_blocks.().img_attn.norm.query_norm.scale": [ + "attn.norm_q.weight", + ], + "double_blocks.().img_attn.norm.key_norm.scale": [ + "attn.norm_k.weight", + ], + "double_blocks.().txt_attn.norm.query_norm.scale": [ + "attn.norm_added_q.weight", + ], + "double_blocks.().txt_attn.norm.key_norm.scale": [ + "attn.norm_added_k.weight", + ], + "double_blocks.().img_mlp.0.weight": [ + "ff.net.0.proj.weight", + ], + 
def is_in_diffusers_map(k):
    """Tell whether key ``k`` ends with any diffusers weight name listed in ``diffusers_map``."""
    # str.endswith accepts a tuple of suffixes, so each value list is tested in one call.
    return any(k.endswith(tuple(candidates)) for candidates in diffusers_map.values())
def swap_scale_shift(weight):
    """Return ``weight`` with its two halves along dim 0 exchanged.

    The tensor is split into two chunks along the first dimension and
    concatenated again in the opposite order.
    """
    leading, trailing = torch.chunk(weight, 2, dim=0)
    return torch.cat((trailing, leading), dim=0)
+ found = True + for weight in weights: + if not (f"{block_prefix}{weight}" in diffusers): + found = False + if found: + flux_values[key.replace("()", f"{b}")] = [ + f"{block_prefix}{weight}" for weight in weights] + +for key, weights in diffusers_map.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + found = True + for weight in weights: + if not (f"{weight}" in diffusers): + found = False + if found: + flux_values[key] = [f"{weight}" for weight in weights] + +flux = {} + +for key, values in tqdm.tqdm(flux_values.items()): + if len(values) == 1: + flux[key] = original_safetensors[diffusers[values[0]] + ].get_tensor(values[0]).to("cpu") + else: + flux[key] = torch.cat( + [ + original_safetensors[diffusers[value] + ].get_tensor(value).to("cpu") + for value in values + ] + ) + +if "norm_out.linear.weight" in diffusers: + flux["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift( + original_safetensors[diffusers["norm_out.linear.weight"]].get_tensor( + "norm_out.linear.weight").to("cpu") + ) +if "norm_out.linear.bias" in diffusers: + flux["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift( + original_safetensors[diffusers["norm_out.linear.bias"]].get_tensor( + "norm_out.linear.bias").to("cpu") + ) + + +def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn): + # Define the float8 range + min_val = torch.finfo(dtype).min + max_val = torch.finfo(dtype).max + + # Clip values to float8 range + tensor = torch.clamp(tensor, min_val, max_val) + + # Convert to float32 for calculations + tensor = tensor.float() + + # Get the nearest representable float8 values + lower = torch.floor(tensor * 256) / 256 + upper = torch.ceil(tensor * 256) / 256 + + # Calculate the probability of rounding up + prob = (tensor - lower) / (upper - lower) + + # Generate random values for stochastic rounding + rand = torch.rand_like(tensor) + + # Perform stochastic rounding + rounded = torch.where(rand < prob, upper, lower) + + # Convert back to 
float8 + return rounded.to(dtype) + + +# set all the keys to bf16 +for key in flux.keys(): + if do_8_bit: + flux[key] = stochastic_round_to( + flux[key], torch.float8_e4m3fn).to('cpu') + else: + flux[key] = flux[key].clone().to('cpu', torch.bfloat16) + +# load the quantized state dict +quantized_state_dict = safetensors.torch.load_file(quantized_state_dict_path) + +transformer_pre = "model.diffusion_model." +did_print = False +# remove old parts +for key in list(quantized_state_dict.keys()): + if key.startswith(transformer_pre): + if not did_print: + # print("dtype: ", quantized_state_dict[key].dtype) + did_print = True + del quantized_state_dict[key] + +# add the new parts +for key, value in flux.items(): + quantized_state_dict[transformer_pre + key] = value + + +meta = OrderedDict() +meta['format'] = 'pt' +# date format like 2024-08-01 YYYY-MM-DD +meta['modelspec.date'] = date.today().strftime("%Y-%m-%d") +meta['modelspec.title'] = "Flex.1-alpha" +meta['modelspec.author'] = "Ostris, LLC" +meta['modelspec.license'] = "Apache-2.0" +meta['modelspec.implementation'] = "https://github.com/black-forest-labs/flux" +meta['modelspec.architecture'] = "Flex.1-alpha" +meta['modelspec.description'] = "Flex.1-alpha" + + +os.makedirs(os.path.dirname(flux_path), exist_ok=True) + +print(f"Saving to {flux_path}") + +safetensors.torch.save_file(quantized_state_dict, flux_path, metadata=meta) + +print("Done.") diff --git a/ai-toolkit/scripts/convert_diffusers_to_comfy_transformer_only.py b/ai-toolkit/scripts/convert_diffusers_to_comfy_transformer_only.py new file mode 100644 index 0000000000000000000000000000000000000000..9973c0878d0cc574cbd0dfd1554ed7c1f7b1c874 --- /dev/null +++ b/ai-toolkit/scripts/convert_diffusers_to_comfy_transformer_only.py @@ -0,0 +1,457 @@ +####################################################### +# Convert Diffusers Flux/Flex to diffusion model ComfyUI safetensors file +# This will only have the transformer weights, not the TEs and VAE +# You can save the 
transformer weights as bf16 or 8-bit with the --do_8_bit flag +# You can also save with scaled 8-bit using the --do_8bit_scaled flag +# +# Call like this for 8-bit transformer weights with stochastic rounding: +# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8_bit +# +# Call like this for 8-bit transformer weights with scaling: +# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8bit_scaled +# +# Call like this for bf16 transformer weights: +# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors +# +# Output should go in ComfyUI/models/diffusion_models/ +# +####################################################### + + +import argparse +from datetime import date +import json +import os +from pathlib import Path +import safetensors +import safetensors.torch +import torch +import tqdm +from collections import OrderedDict + + +parser = argparse.ArgumentParser() + +parser.add_argument("diffusers_path", type=str, + help="Path to the original Flux diffusers folder.") +parser.add_argument("flux_path", type=str, + help="Output path for the Flux safetensors file.") +parser.add_argument("--do_8_bit", action="store_true", + help="Use 8-bit weights with stochastic rounding instead of bf16.") +parser.add_argument("--do_8bit_scaled", action="store_true", + help="Use scaled 8-bit weights instead of bf16.") +args = parser.parse_args() + +flux_path = Path(args.flux_path) +diffusers_path = Path(args.diffusers_path) + +if os.path.exists(os.path.join(diffusers_path, "transformer")): + diffusers_path = Path(os.path.join(diffusers_path, "transformer")) + +do_8_bit = args.do_8_bit +do_8bit_scaled = args.do_8bit_scaled + +# Don't allow both flags to be active simultaneously +if do_8_bit and do_8bit_scaled: + print("Error: Cannot use both --do_8_bit and --do_8bit_scaled 
at the same time.") + exit() + +if not os.path.exists(flux_path.parent): + os.makedirs(flux_path.parent) + +if not diffusers_path.exists(): + print(f"Error: Missing transformer folder: {diffusers_path}") + exit() + +original_json_path = Path.joinpath( + diffusers_path, "diffusion_pytorch_model.safetensors.index.json") + +if not original_json_path.exists(): + print(f"Error: Missing transformer index json: {original_json_path}") + exit() + +with open(original_json_path, "r", encoding="utf-8") as f: + original_json = json.load(f) + +diffusers_map = { + "time_in.in_layer.weight": [ + "time_text_embed.timestep_embedder.linear_1.weight", + ], + "time_in.in_layer.bias": [ + "time_text_embed.timestep_embedder.linear_1.bias", + ], + "time_in.out_layer.weight": [ + "time_text_embed.timestep_embedder.linear_2.weight", + ], + "time_in.out_layer.bias": [ + "time_text_embed.timestep_embedder.linear_2.bias", + ], + "vector_in.in_layer.weight": [ + "time_text_embed.text_embedder.linear_1.weight", + ], + "vector_in.in_layer.bias": [ + "time_text_embed.text_embedder.linear_1.bias", + ], + "vector_in.out_layer.weight": [ + "time_text_embed.text_embedder.linear_2.weight", + ], + "vector_in.out_layer.bias": [ + "time_text_embed.text_embedder.linear_2.bias", + ], + "guidance_in.in_layer.weight": [ + "time_text_embed.guidance_embedder.linear_1.weight", + ], + "guidance_in.in_layer.bias": [ + "time_text_embed.guidance_embedder.linear_1.bias", + ], + "guidance_in.out_layer.weight": [ + "time_text_embed.guidance_embedder.linear_2.weight", + ], + "guidance_in.out_layer.bias": [ + "time_text_embed.guidance_embedder.linear_2.bias", + ], + "txt_in.weight": [ + "context_embedder.weight", + ], + "txt_in.bias": [ + "context_embedder.bias", + ], + "img_in.weight": [ + "x_embedder.weight", + ], + "img_in.bias": [ + "x_embedder.bias", + ], + "double_blocks.().img_mod.lin.weight": [ + "norm1.linear.weight", + ], + "double_blocks.().img_mod.lin.bias": [ + "norm1.linear.bias", + ], + 
"double_blocks.().txt_mod.lin.weight": [ + "norm1_context.linear.weight", + ], + "double_blocks.().txt_mod.lin.bias": [ + "norm1_context.linear.bias", + ], + "double_blocks.().img_attn.qkv.weight": [ + "attn.to_q.weight", + "attn.to_k.weight", + "attn.to_v.weight", + ], + "double_blocks.().img_attn.qkv.bias": [ + "attn.to_q.bias", + "attn.to_k.bias", + "attn.to_v.bias", + ], + "double_blocks.().txt_attn.qkv.weight": [ + "attn.add_q_proj.weight", + "attn.add_k_proj.weight", + "attn.add_v_proj.weight", + ], + "double_blocks.().txt_attn.qkv.bias": [ + "attn.add_q_proj.bias", + "attn.add_k_proj.bias", + "attn.add_v_proj.bias", + ], + "double_blocks.().img_attn.norm.query_norm.scale": [ + "attn.norm_q.weight", + ], + "double_blocks.().img_attn.norm.key_norm.scale": [ + "attn.norm_k.weight", + ], + "double_blocks.().txt_attn.norm.query_norm.scale": [ + "attn.norm_added_q.weight", + ], + "double_blocks.().txt_attn.norm.key_norm.scale": [ + "attn.norm_added_k.weight", + ], + "double_blocks.().img_mlp.0.weight": [ + "ff.net.0.proj.weight", + ], + "double_blocks.().img_mlp.0.bias": [ + "ff.net.0.proj.bias", + ], + "double_blocks.().img_mlp.2.weight": [ + "ff.net.2.weight", + ], + "double_blocks.().img_mlp.2.bias": [ + "ff.net.2.bias", + ], + "double_blocks.().txt_mlp.0.weight": [ + "ff_context.net.0.proj.weight", + ], + "double_blocks.().txt_mlp.0.bias": [ + "ff_context.net.0.proj.bias", + ], + "double_blocks.().txt_mlp.2.weight": [ + "ff_context.net.2.weight", + ], + "double_blocks.().txt_mlp.2.bias": [ + "ff_context.net.2.bias", + ], + "double_blocks.().img_attn.proj.weight": [ + "attn.to_out.0.weight", + ], + "double_blocks.().img_attn.proj.bias": [ + "attn.to_out.0.bias", + ], + "double_blocks.().txt_attn.proj.weight": [ + "attn.to_add_out.weight", + ], + "double_blocks.().txt_attn.proj.bias": [ + "attn.to_add_out.bias", + ], + "single_blocks.().modulation.lin.weight": [ + "norm.linear.weight", + ], + "single_blocks.().modulation.lin.bias": [ + "norm.linear.bias", + ], + 
def is_in_diffusers_map(k):
    """Report whether ``k`` has one of the weight names in ``diffusers_map`` as a suffix."""
    return any(
        k.endswith(suffix)
        for suffixes in diffusers_map.values()
        for suffix in suffixes
    )
for f in original_safetensors} + + +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +flux_values = {} + +for b in range(transformer_blocks): + for key, weights in diffusers_map.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + found = True + for weight in weights: + if not (f"{block_prefix}{weight}" in diffusers): + found = False + if found: + flux_values[key.replace("()", f"{b}")] = [ + f"{block_prefix}{weight}" for weight in weights] +for b in range(single_transformer_blocks): + for key, weights in diffusers_map.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + found = True + for weight in weights: + if not (f"{block_prefix}{weight}" in diffusers): + found = False + if found: + flux_values[key.replace("()", f"{b}")] = [ + f"{block_prefix}{weight}" for weight in weights] + +for key, weights in diffusers_map.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + found = True + for weight in weights: + if not (f"{weight}" in diffusers): + found = False + if found: + flux_values[key] = [f"{weight}" for weight in weights] + +flux = {} + +for key, values in tqdm.tqdm(flux_values.items()): + if len(values) == 1: + flux[key] = original_safetensors[diffusers[values[0]] + ].get_tensor(values[0]).to("cpu") + else: + flux[key] = torch.cat( + [ + original_safetensors[diffusers[value] + ].get_tensor(value).to("cpu") + for value in values + ] + ) + +if "norm_out.linear.weight" in diffusers: + flux["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift( + original_safetensors[diffusers["norm_out.linear.weight"]].get_tensor( + "norm_out.linear.weight").to("cpu") + ) +if "norm_out.linear.bias" in diffusers: + flux["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift( + original_safetensors[diffusers["norm_out.linear.bias"]].get_tensor( + 
"norm_out.linear.bias").to("cpu") + ) + + +def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn): + # Define the float8 range + min_val = torch.finfo(dtype).min + max_val = torch.finfo(dtype).max + + # Clip values to float8 range + tensor = torch.clamp(tensor, min_val, max_val) + + # Convert to float32 for calculations + tensor = tensor.float() + + # Get the nearest representable float8 values + lower = torch.floor(tensor * 256) / 256 + upper = torch.ceil(tensor * 256) / 256 + + # Calculate the probability of rounding up + prob = (tensor - lower) / (upper - lower) + + # Generate random values for stochastic rounding + rand = torch.rand_like(tensor) + + # Perform stochastic rounding + rounded = torch.where(rand < prob, upper, lower) + + # Convert back to float8 + return rounded.to(dtype) + + +# List of keys that should not be scaled (usually embedding layers and biases) +blacklist = [] +for key in flux.keys(): + if not key.endswith(".weight") or "embed" in key: + blacklist.append(key) + +# Function to scale weights for 8-bit quantization +def scale_weights_to_8bit(tensor, max_value=416.0, dtype=torch.float8_e4m3fn): + # Get the limits of the dtype + min_val = torch.finfo(dtype).min + max_val = torch.finfo(dtype).max + + # Only process 2D tensors that are not in the blacklist + if tensor.dim() == 2: + # Calculate the scaling factor + abs_max = torch.max(torch.abs(tensor)) + scale = abs_max / max_value + + # Scale the tensor and clip to float8 range + scaled_tensor = (tensor / scale).clip(min=min_val, max=max_val).to(dtype) + + return scaled_tensor, scale + else: + # For tensors that shouldn't be scaled, just convert to float8 + return tensor.clip(min=min_val, max=max_val).to(dtype), None + + +# set all the keys to appropriate dtype +if do_8_bit: + print("Converting to 8-bit with stochastic rounding...") + for key in flux.keys(): + flux[key] = stochastic_round_to( + flux[key], torch.float8_e4m3fn).to('cpu') +elif do_8bit_scaled: + print("Converting to scaled 
8-bit...") + scales = {} + for key in tqdm.tqdm(flux.keys()): + if key.endswith(".weight") and key not in blacklist: + flux[key], scale = scale_weights_to_8bit(flux[key]) + if scale is not None: + scale_key = key[:-len(".weight")] + ".scale_weight" + scales[scale_key] = scale + else: + # For non-weight tensors or blacklisted ones, just convert without scaling + min_val = torch.finfo(torch.float8_e4m3fn).min + max_val = torch.finfo(torch.float8_e4m3fn).max + flux[key] = flux[key].clip(min=min_val, max=max_val).to(torch.float8_e4m3fn).to('cpu') + + # Add all the scales to the flux dictionary + flux.update(scales) + + # Add a marker tensor to indicate this is a scaled fp8 model + flux["scaled_fp8"] = torch.tensor([]).to(torch.float8_e4m3fn) +else: + print("Converting to bfloat16...") + for key in flux.keys(): + flux[key] = flux[key].clone().to('cpu', torch.bfloat16) + +meta = OrderedDict() +meta['format'] = 'pt' +# date format like 2024-08-01 YYYY-MM-DD +meta['modelspec.date'] = date.today().strftime("%Y-%m-%d") + +os.makedirs(os.path.dirname(flux_path), exist_ok=True) + +print(f"Saving to {flux_path}") + +safetensors.torch.save_file(flux, flux_path, metadata=meta) + +print("Done.") \ No newline at end of file diff --git a/ai-toolkit/scripts/convert_lora_to_peft_format.py b/ai-toolkit/scripts/convert_lora_to_peft_format.py new file mode 100644 index 0000000000000000000000000000000000000000..3034db646ce0cbf784940df17a45e2468063f485 --- /dev/null +++ b/ai-toolkit/scripts/convert_lora_to_peft_format.py @@ -0,0 +1,91 @@ +# currently only works with flux as support is not quite there yet + +import argparse +import os.path +from collections import OrderedDict + +parser = argparse.ArgumentParser() +parser.add_argument( + 'input_path', + type=str, + help='Path to original sdxl model' +) +parser.add_argument( + 'output_path', + type=str, + help='output path' +) +args = parser.parse_args() +args.input_path = os.path.abspath(args.input_path) +args.output_path = 
os.path.abspath(args.output_path) + +from safetensors.torch import load_file, save_file + +meta = OrderedDict() +meta['format'] = 'pt' + +state_dict = load_file(args.input_path) + +# peft doesnt have an alpha so we need to scale the weights +alpha_keys = [ + 'lora_transformer_single_transformer_blocks_0_attn_to_q.alpha' # flux +] + +# keys where the rank is in the first dimension +rank_idx0_keys = [ + 'lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight' + # 'transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight' +] + +alpha = None +rank = None + +for key in rank_idx0_keys: + if key in state_dict: + rank = int(state_dict[key].shape[0]) + break + +if rank is None: + raise ValueError(f'Could not find rank in state dict') + +for key in alpha_keys: + if key in state_dict: + alpha = int(state_dict[key]) + break + +if alpha is None: + # set to rank if not found + alpha = rank + + +up_multiplier = alpha / rank + +new_state_dict = {} + +for key, value in state_dict.items(): + if key.endswith('.alpha'): + continue + + orig_dtype = value.dtype + + new_val = value.float() * up_multiplier + + new_key = key + new_key = new_key.replace('lora_transformer_', 'transformer.') + for i in range(100): + new_key = new_key.replace(f'transformer_blocks_{i}_', f'transformer_blocks.{i}.') + new_key = new_key.replace('lora_down', 'lora_A') + new_key = new_key.replace('lora_up', 'lora_B') + new_key = new_key.replace('_lora', '.lora') + new_key = new_key.replace('attn_', 'attn.') + new_key = new_key.replace('ff_', 'ff.') + new_key = new_key.replace('context_net_', 'context.net.') + new_key = new_key.replace('0_proj', '0.proj') + new_key = new_key.replace('norm_linear', 'norm.linear') + new_key = new_key.replace('norm_out_linear', 'norm_out.linear') + new_key = new_key.replace('to_out_', 'to_out.') + + new_state_dict[new_key] = new_val.to(orig_dtype) + +save_file(new_state_dict, args.output_path, meta) +print(f'Saved to {args.output_path}') diff --git 
a/ai-toolkit/scripts/extract_lora_from_flex.py b/ai-toolkit/scripts/extract_lora_from_flex.py new file mode 100644 index 0000000000000000000000000000000000000000..c80c889221164473d6e74202f27917e9714c1f62 --- /dev/null +++ b/ai-toolkit/scripts/extract_lora_from_flex.py @@ -0,0 +1,245 @@ +import os +from tqdm import tqdm +import argparse +from collections import OrderedDict + +parser = argparse.ArgumentParser(description="Extract LoRA from Flex") +parser.add_argument("--base", type=str, default="ostris/Flex.1-alpha", help="Base model path") +parser.add_argument("--tuned", type=str, required=True, help="Tuned model path") +parser.add_argument("--output", type=str, required=True, help="Output path for lora") +parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction") +parser.add_argument("--gpu", type=int, default=0, help="GPU to process extraction") +parser.add_argument("--full", action="store_true", help="Do a full transformer extraction, not just transformer blocks") + +args = parser.parse_args() + +if True: + # set cuda environment variable + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + import torch + from safetensors.torch import load_file, save_file + from lycoris.utils import extract_linear, extract_conv, make_sparse + from diffusers import FluxTransformer2DModel + +base = args.base +tuned = args.tuned +output_path = args.output +dim = args.rank + +os.makedirs(os.path.dirname(output_path), exist_ok=True) + +state_dict_base = {} +state_dict_tuned = {} + +output_dict = {} + +@torch.no_grad() +def extract_diff( + base_unet, + db_unet, + mode="fixed", + linear_mode_param=0, + conv_mode_param=0, + extract_device="cpu", + use_bias=False, + sparsity=0.98, + # small_conv=True, + small_conv=False, +): + UNET_TARGET_REPLACE_MODULE = [ + "Linear", + "Conv2d", + "LayerNorm", + "GroupNorm", + "GroupNorm32", + "LoRACompatibleLinear", + "LoRACompatibleConv" + ] + LORA_PREFIX_UNET = "transformer" + + def make_state_dict( + prefix, + 
root_module: torch.nn.Module, + target_module: torch.nn.Module, + target_replace_modules, + ): + loras = {} + temp = {} + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + temp[name] = module + + for name, module in tqdm( + list((n, m) for n, m in target_module.named_modules() if n in temp) + ): + weights = temp[name] + lora_name = prefix + "." + name + # lora_name = lora_name.replace(".", "_") + layer = module.__class__.__name__ + if 'transformer_blocks' not in lora_name and not args.full: + continue + + if layer in { + "Linear", + "Conv2d", + "LayerNorm", + "GroupNorm", + "GroupNorm32", + "Embedding", + "LoRACompatibleLinear", + "LoRACompatibleConv" + }: + root_weight = module.weight + try: + if torch.allclose(root_weight, weights.weight): + continue + except: + continue + else: + continue + module = module.to(extract_device, torch.float32) + weights = weights.to(extract_device, torch.float32) + + if mode == "full": + decompose_mode = "full" + elif layer == "Linear": + weight, decompose_mode = extract_linear( + (root_weight - weights.weight), + mode, + linear_mode_param, + device=extract_device, + ) + if decompose_mode == "low rank": + extract_a, extract_b, diff = weight + elif layer == "Conv2d": + is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1 + weight, decompose_mode = extract_conv( + (root_weight - weights.weight), + mode, + linear_mode_param if is_linear else conv_mode_param, + device=extract_device, + ) + if decompose_mode == "low rank": + extract_a, extract_b, diff = weight + if small_conv and not is_linear and decompose_mode == "low rank": + dim = extract_a.size(0) + (extract_c, extract_a, _), _ = extract_conv( + extract_a.transpose(0, 1), + "fixed", + dim, + extract_device, + True, + ) + extract_a = extract_a.transpose(0, 1) + extract_c = extract_c.transpose(0, 1) + loras[f"{lora_name}.lora_mid.weight"] = ( + extract_c.detach().cpu().contiguous().half() + ) + diff = ( + ( + 
root_weight + - torch.einsum( + "i j k l, j r, p i -> p r k l", + extract_c, + extract_a.flatten(1, -1), + extract_b.flatten(1, -1), + ) + ) + .detach() + .cpu() + .contiguous() + ) + del extract_c + else: + module = module.to("cpu") + weights = weights.to("cpu") + continue + + if decompose_mode == "low rank": + loras[f"{lora_name}.lora_A.weight"] = ( + extract_a.detach().cpu().contiguous().half() + ) + loras[f"{lora_name}.lora_B.weight"] = ( + extract_b.detach().cpu().contiguous().half() + ) + # loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half() + if use_bias: + diff = diff.detach().cpu().reshape(extract_b.size(0), -1) + sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce() + + indices = sparse_diff.indices().to(torch.int16) + values = sparse_diff.values().half() + loras[f"{lora_name}.bias_indices"] = indices + loras[f"{lora_name}.bias_values"] = values + loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to( + torch.int16 + ) + del extract_a, extract_b, diff + elif decompose_mode == "full": + if "Norm" in layer: + w_key = "w_norm" + b_key = "b_norm" + else: + w_key = "diff" + b_key = "diff_b" + weight_diff = module.weight - weights.weight + loras[f"{lora_name}.{w_key}"] = ( + weight_diff.detach().cpu().contiguous().half() + ) + if getattr(weights, "bias", None) is not None: + bias_diff = module.bias - weights.bias + loras[f"{lora_name}.{b_key}"] = ( + bias_diff.detach().cpu().contiguous().half() + ) + else: + raise NotImplementedError + module = module.to("cpu", torch.bfloat16) + weights = weights.to("cpu", torch.bfloat16) + return loras + + all_loras = {} + + all_loras |= make_state_dict( + LORA_PREFIX_UNET, + base_unet, + db_unet, + UNET_TARGET_REPLACE_MODULE, + ) + del base_unet, db_unet + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + all_lora_name = set() + for k in all_loras: + lora_name, weight = k.rsplit(".", 1) + all_lora_name.add(lora_name) + print(len(all_lora_name)) + return all_loras + + +# find 
all the .safetensors files and load them +print("Loading Base") +base_model = FluxTransformer2DModel.from_pretrained(base, subfolder="transformer", torch_dtype=torch.bfloat16) + +print("Loading Tuned") +tuned_model = FluxTransformer2DModel.from_pretrained(tuned, subfolder="transformer", torch_dtype=torch.bfloat16) + +output_dict = extract_diff( + base_model, + tuned_model, + mode="fixed", + linear_mode_param=dim, + conv_mode_param=dim, + extract_device="cuda", + use_bias=False, + sparsity=0.98, + small_conv=False, +) + +meta = OrderedDict() +meta['format'] = 'pt' + +save_file(output_dict, output_path, metadata=meta) + +print("Done") diff --git a/ai-toolkit/scripts/generate_sampler_step_scales.py b/ai-toolkit/scripts/generate_sampler_step_scales.py new file mode 100644 index 0000000000000000000000000000000000000000..11efb3183becb48ec4a485565d53049fb6a8d11c --- /dev/null +++ b/ai-toolkit/scripts/generate_sampler_step_scales.py @@ -0,0 +1,20 @@ +import argparse +import torch +import os +from diffusers import StableDiffusionPipeline +import sys + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +# add project root to path +sys.path.append(PROJECT_ROOT) + +SAMPLER_SCALES_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'samplers_scales') + + +parser = argparse.ArgumentParser(description='Process some images.') +add_arg = parser.add_argument +add_arg('--model', type=str, required=True, help='Path to model') +add_arg('--sampler', type=str, required=True, help='Name of sampler') + +args = parser.parse_args() + diff --git a/ai-toolkit/scripts/make_diffusers_model.py b/ai-toolkit/scripts/make_diffusers_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4536a9215540dd01321ef1426665db9d6ef6347f --- /dev/null +++ b/ai-toolkit/scripts/make_diffusers_model.py @@ -0,0 +1,61 @@ +import argparse +from collections import OrderedDict +import sys +import os +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 
+sys.path.append(ROOT_DIR) + +import torch + +from toolkit.config_modules import ModelConfig +from toolkit.stable_diffusion_model import StableDiffusion + + +parser = argparse.ArgumentParser() +parser.add_argument( + 'input_path', + type=str, + help='Path to original sdxl model' +) +parser.add_argument( + 'output_path', + type=str, + help='output path' +) +parser.add_argument('--sdxl', action='store_true', help='is sdxl model') +parser.add_argument('--refiner', action='store_true', help='is refiner model') +parser.add_argument('--ssd', action='store_true', help='is ssd model') +parser.add_argument('--sd2', action='store_true', help='is sd 2 model') + +args = parser.parse_args() +device = torch.device('cpu') +dtype = torch.float32 + +print(f"Loading model from {args.input_path}") + + +diffusers_model_config = ModelConfig( + name_or_path=args.input_path, + is_xl=args.sdxl, + is_v2=args.sd2, + is_ssd=args.ssd, + dtype=dtype, + ) +diffusers_sd = StableDiffusion( + model_config=diffusers_model_config, + device=device, + dtype=dtype, +) +diffusers_sd.load_model() + + +print(f"Loaded model from {args.input_path}") + +diffusers_sd.pipeline.fuse_lora() + +meta = OrderedDict() + +diffusers_sd.save(args.output_path, meta=meta) + + +print(f"Saved to {args.output_path}") diff --git a/ai-toolkit/scripts/make_lcm_sdxl_model.py b/ai-toolkit/scripts/make_lcm_sdxl_model.py new file mode 100644 index 0000000000000000000000000000000000000000..20e95ce795a39fe2837b80fcbf1950c256ad4c59 --- /dev/null +++ b/ai-toolkit/scripts/make_lcm_sdxl_model.py @@ -0,0 +1,67 @@ +import argparse +from collections import OrderedDict + +import torch + +from toolkit.config_modules import ModelConfig +from toolkit.stable_diffusion_model import StableDiffusion + + +parser = argparse.ArgumentParser() +parser.add_argument( + 'input_path', + type=str, + help='Path to original sdxl model' +) +parser.add_argument( + 'output_path', + type=str, + help='output path' +) +parser.add_argument('--sdxl', 
action='store_true', help='is sdxl model') +parser.add_argument('--refiner', action='store_true', help='is refiner model') +parser.add_argument('--ssd', action='store_true', help='is ssd model') +parser.add_argument('--sd2', action='store_true', help='is sd 2 model') + +args = parser.parse_args() +device = torch.device('cpu') +dtype = torch.float32 + +print(f"Loading model from {args.input_path}") + +if args.sdxl: + adapter_id = "latent-consistency/lcm-lora-sdxl" +if args.refiner: + adapter_id = "latent-consistency/lcm-lora-sdxl" +elif args.ssd: + adapter_id = "latent-consistency/lcm-lora-ssd-1b" +else: + adapter_id = "latent-consistency/lcm-lora-sdv1-5" + + +diffusers_model_config = ModelConfig( + name_or_path=args.input_path, + is_xl=args.sdxl, + is_v2=args.sd2, + is_ssd=args.ssd, + dtype=dtype, + ) +diffusers_sd = StableDiffusion( + model_config=diffusers_model_config, + device=device, + dtype=dtype, +) +diffusers_sd.load_model() + + +print(f"Loaded model from {args.input_path}") + +diffusers_sd.pipeline.load_lora_weights(adapter_id) +diffusers_sd.pipeline.fuse_lora() + +meta = OrderedDict() + +diffusers_sd.save(args.output_path, meta=meta) + + +print(f"Saved to {args.output_path}") diff --git a/ai-toolkit/scripts/patch_te_adapter.py b/ai-toolkit/scripts/patch_te_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7249a46d8e566c3889538c465359e6c66b1c9602 --- /dev/null +++ b/ai-toolkit/scripts/patch_te_adapter.py @@ -0,0 +1,42 @@ +import torch +from safetensors.torch import save_file, load_file +from collections import OrderedDict +meta = OrderedDict() +meta["format"] ="pt" + +attn_dict = load_file("/mnt/Train/out/ip_adapter/sd15_bigG/sd15_bigG_000266000.safetensors") +state_dict = load_file("/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors") + +attn_list = [] +for key, value in state_dict.items(): + if "attn1" in key: + attn_list.append(key) + +attn_names = 
['down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor'] + +adapter_names = [] +for i in range(100): + if f'te_adapter.adapter_modules.{i}.to_k_adapter.weight' in attn_dict: + adapter_names.append(f"te_adapter.adapter_modules.{i}.adapter") + + +for i in range(len(adapter_names)): + adapter_name = adapter_names[i] + attn_name = attn_names[i] + adapter_k_name = adapter_name[:-8] + '.to_k_adapter.weight' + adapter_v_name = adapter_name[:-8] + '.to_v_adapter.weight' + state_k_name = attn_name.replace(".processor", ".to_k.weight") + state_v_name = attn_name.replace(".processor", ".to_v.weight") + if adapter_k_name in attn_dict: + state_dict[state_k_name] = attn_dict[adapter_k_name] + state_dict[state_v_name] = attn_dict[adapter_v_name] + else: + print("adapter_k_name", adapter_k_name) + print("state_k_name", state_k_name) + +for key, value in state_dict.items(): + state_dict[key] = value.cpu().to(torch.float16) + +save_file(state_dict, 
"/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors", metadata=meta) + +print("Done") diff --git a/ai-toolkit/scripts/repair_dataset_folder.py b/ai-toolkit/scripts/repair_dataset_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9d277508c19046b5737620a01b9eba09635e98 --- /dev/null +++ b/ai-toolkit/scripts/repair_dataset_folder.py @@ -0,0 +1,65 @@ +import argparse +from PIL import Image +from PIL.ImageOps import exif_transpose +from tqdm import tqdm +import os + +parser = argparse.ArgumentParser(description='Process some images.') +parser.add_argument("input_folder", type=str, help="Path to folder containing images") + +args = parser.parse_args() + +img_types = ['.jpg', '.jpeg', '.png', '.webp'] + +# find all images in the input folder +images = [] +for root, _, files in os.walk(args.input_folder): + for file in files: + if file.lower().endswith(tuple(img_types)): + images.append(os.path.join(root, file)) +print(f"Found {len(images)} images") + +num_skipped = 0 +num_repaired = 0 +num_deleted = 0 + +pbar = tqdm(total=len(images), desc=f"Repaired {num_repaired} images", unit="image") +for img_path in images: + filename = os.path.basename(img_path) + filename_no_ext, file_extension = os.path.splitext(filename) + # if it is jpg, ignore + if file_extension.lower() == '.jpg': + num_skipped += 1 + pbar.update(1) + + continue + + try: + img = Image.open(img_path) + except Exception as e: + print(f"Error opening {img_path}: {e}") + # delete it + os.remove(img_path) + num_deleted += 1 + pbar.update(1) + pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}") + continue + + + try: + img = exif_transpose(img) + except Exception as e: + print(f"Error rotating {img_path}: {e}") + + new_path = os.path.join(os.path.dirname(img_path), filename_no_ext + '.jpg') + + img = img.convert("RGB") + img.save(new_path, quality=95) + # remove the old file + os.remove(img_path) + 
num_repaired += 1 + pbar.update(1) + # update pbar + pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}") + +print("Done") \ No newline at end of file diff --git a/ai-toolkit/scripts/update_sponsors.py b/ai-toolkit/scripts/update_sponsors.py new file mode 100644 index 0000000000000000000000000000000000000000..d80b4bc180dee3cb4ad312b98f3fb1d509dde8d3 --- /dev/null +++ b/ai-toolkit/scripts/update_sponsors.py @@ -0,0 +1,309 @@ +import os +import requests +import json +from datetime import datetime +from dotenv import load_dotenv + +# Load environment variables from .env file +env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env") +load_dotenv(dotenv_path=env_path) + +# API credentials +PATREON_TOKEN = os.getenv("PATREON_ACCESS_TOKEN") +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +GITHUB_USERNAME = os.getenv("GITHUB_USERNAME") +GITHUB_ORG = os.getenv("GITHUB_ORG") # Organization name (optional) + +# Output file +README_PATH = "SUPPORTERS.md" + +def fetch_patreon_supporters(): + """Fetch current Patreon supporters""" + print("Fetching Patreon supporters...") + + headers = { + "Authorization": f"Bearer {PATREON_TOKEN}", + "Content-Type": "application/json" + } + + url = "https://www.patreon.com/api/oauth2/v2/campaigns" + + try: + # First get the campaign ID + campaign_response = requests.get(url, headers=headers) + campaign_response.raise_for_status() + campaign_data = campaign_response.json() + + if not campaign_data.get('data'): + print("No campaigns found for this Patreon account") + return [] + + campaign_id = campaign_data['data'][0]['id'] + + # Now get the supporters for this campaign + members_url = f"https://www.patreon.com/api/oauth2/v2/campaigns/{campaign_id}/members" + params = { + "include": "user", + "fields[member]": "full_name,is_follower,patron_status", # Removed profile_url + "fields[user]": "image_url" + } + + supporters = [] + while members_url: + members_response = 
requests.get(members_url, headers=headers, params=params) + members_response.raise_for_status() + members_data = members_response.json() + + # Process the response to extract active patrons + for member in members_data.get('data', []): + attributes = member.get('attributes', {}) + + # Only include active patrons + if attributes.get('patron_status') == 'active_patron': + name = attributes.get('full_name', 'Anonymous Supporter') + + # Get user data which contains the profile image + user_id = member.get('relationships', {}).get('user', {}).get('data', {}).get('id') + profile_image = None + profile_url = None # Removed profile_url since it's not supported + + if user_id: + for included in members_data.get('included', []): + if included.get('id') == user_id and included.get('type') == 'user': + profile_image = included.get('attributes', {}).get('image_url') + break + + supporters.append({ + 'name': name, + 'profile_image': profile_image, + 'profile_url': profile_url, # This will be None + 'platform': 'Patreon', + 'amount': 0 # Placeholder, as Patreon API doesn't provide this in the current response + }) + + # Handle pagination + members_url = members_data.get('links', {}).get('next') + + print(f"Found {len(supporters)} active Patreon supporters") + return supporters + + except requests.exceptions.RequestException as e: + print(f"Error fetching Patreon data: {e}") + print(f"Response content: {e.response.content if hasattr(e, 'response') else 'No response content'}") + return [] + +def fetch_github_sponsors(): + """Fetch current GitHub sponsors for a user or organization""" + print("Fetching GitHub sponsors...") + + headers = { + "Authorization": f"Bearer {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json" + } + + # Determine if we're fetching for a user or an organization + entity_type = "organization" if GITHUB_ORG else "user" + entity_name = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME + + if not entity_name: + print("Error: Neither GITHUB_USERNAME nor 
GITHUB_ORG is set") + return [] + + # Different GraphQL query structure based on entity type + if entity_type == "user": + query = """ + query { + user(login: "%s") { + sponsorshipsAsMaintainer(first: 100) { + nodes { + sponsorEntity { + ... on User { + login + name + avatarUrl + url + } + ... on Organization { + login + name + avatarUrl + url + } + } + tier { + monthlyPriceInDollars + } + isOneTimePayment + isActive + } + } + } + } + """ % entity_name + else: # organization + query = """ + query { + organization(login: "%s") { + sponsorshipsAsMaintainer(first: 100) { + nodes { + sponsorEntity { + ... on User { + login + name + avatarUrl + url + } + ... on Organization { + login + name + avatarUrl + url + } + } + tier { + monthlyPriceInDollars + } + isOneTimePayment + isActive + } + } + } + } + """ % entity_name + + try: + response = requests.post( + "https://api.github.com/graphql", + headers=headers, + json={"query": query} + ) + response.raise_for_status() + data = response.json() + + # Process the response - the path to the data differs based on entity type + if entity_type == "user": + sponsors_data = data.get('data', {}).get('user', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', []) + else: + sponsors_data = data.get('data', {}).get('organization', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', []) + + sponsors = [] + for sponsor in sponsors_data: + # Only include active sponsors + if sponsor.get('isActive'): + entity = sponsor.get('sponsorEntity', {}) + name = entity.get('name') or entity.get('login', 'Anonymous Sponsor') + profile_image = entity.get('avatarUrl') + profile_url = entity.get('url') + amount = sponsor.get('tier', {}).get('monthlyPriceInDollars', 0) + + sponsors.append({ + 'name': name, + 'profile_image': profile_image, + 'profile_url': profile_url, + 'platform': 'GitHub Sponsors', + 'amount': amount + }) + + print(f"Found {len(sponsors)} active GitHub sponsors for {entity_type} '{entity_name}'") + return sponsors + + except 
requests.exceptions.RequestException as e: + print(f"Error fetching GitHub sponsors data: {e}") + return [] + +def generate_readme(supporters): + """Generate a README.md file with supporter information""" + print(f"Generating {README_PATH}...") + + # Sort supporters by amount (descending) and then by name + supporters.sort(key=lambda x: (-x['amount'], x['name'].lower())) + + # Determine the proper footer links based on what's configured + github_entity = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME + github_entity_type = "orgs" if GITHUB_ORG else "sponsors" + github_sponsor_url = f"https://github.com/{github_entity_type}/{github_entity}" + + with open(README_PATH, "w", encoding="utf-8") as f: + f.write("## Support My Work\n\n") + f.write("If you enjoy my work, or use it for commercial purposes, please consider sponsoring me so I can continue to maintain it. Every bit helps! \n\n") + # Create appropriate call-to-action based on what's configured + cta_parts = [] + if github_entity: + cta_parts.append(f"[Become a sponsor on GitHub]({github_sponsor_url})") + if PATREON_TOKEN: + cta_parts.append("[support me on Patreon](https://www.patreon.com/ostris)") + + if cta_parts: + if GITHUB_ORG: + f.write(f"{' or '.join(cta_parts)}.\n\n") + f.write("Thank you to all my current supporters!\n\n") + + f.write(f"_Last updated: {datetime.now().strftime('%Y-%m-%d')}_\n\n") + + # Write GitHub Sponsors section + github_sponsors = [s for s in supporters if s['platform'] == 'GitHub Sponsors'] + if github_sponsors: + f.write("### GitHub Sponsors\n\n") + for sponsor in github_sponsors: + if sponsor['profile_image']: + f.write(f"\"{sponsor['name']}\" ") + else: + f.write(f"[{sponsor['name']}]({sponsor['profile_url']}) ") + f.write("\n\n") + + # Write Patreon section + patreon_supporters = [s for s in supporters if s['platform'] == 'Patreon'] + if patreon_supporters: + f.write("### Patreon Supporters\n\n") + for supporter in patreon_supporters: + if supporter['profile_image']: + 
f.write(f"\"{supporter['name']}\" ") + else: + f.write(f"[{supporter['name']}]({supporter['profile_url']}) ") + f.write("\n\n") + + f.write("\n---\n\n") + + + print(f"Successfully generated {README_PATH} with {len(supporters)} supporters!") + +def main(): + """Main function""" + print("Starting supporter data collection...") + + # Check if required environment variables are set + missing_vars = [] + if not GITHUB_TOKEN: + missing_vars.append("GITHUB_TOKEN") + + # Either username or org is required for GitHub + if not GITHUB_USERNAME and not GITHUB_ORG: + missing_vars.append("GITHUB_USERNAME or GITHUB_ORG") + + # Patreon token is optional but warn if missing + patreon_enabled = bool(PATREON_TOKEN) + + if missing_vars: + print(f"Error: Missing required environment variables: {', '.join(missing_vars)}") + print("Please add them to your .env file") + return + + if not patreon_enabled: + print("Warning: PATREON_ACCESS_TOKEN not set. Will only fetch GitHub sponsors.") + + # Fetch data from both platforms + patreon_supporters = fetch_patreon_supporters() if PATREON_TOKEN else [] + github_sponsors = fetch_github_sponsors() + + # Combine supporters from both platforms + all_supporters = patreon_supporters + github_sponsors + + if not all_supporters: + print("No supporters found on either platform") + return + + # Generate README + generate_readme(all_supporters) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ai-toolkit/testing/compare_keys.py b/ai-toolkit/testing/compare_keys.py new file mode 100644 index 0000000000000000000000000000000000000000..bf4f95203fe1024daeb66ddd79696875f04578c7 --- /dev/null +++ b/ai-toolkit/testing/compare_keys.py @@ -0,0 +1,99 @@ +import argparse +import os + +import torch +from diffusers.loaders import LoraLoaderMixin +from safetensors.torch import load_file +from collections import OrderedDict +import json +# this was just used to match the vae keys to the diffusers keys +# you probably wont need this. 
Unless they change them.... again... again +# on second thought, you probably will + +device = torch.device('cpu') +dtype = torch.float32 + +parser = argparse.ArgumentParser() + +# require at lease one config file +parser.add_argument( + 'file_1', + nargs='+', + type=str, + help='Path to first safe tensor file' +) + +parser.add_argument( + 'file_2', + nargs='+', + type=str, + help='Path to second safe tensor file' +) + +args = parser.parse_args() + +find_matches = False + +state_dict_file_1 = load_file(args.file_1[0]) +state_dict_1_keys = list(state_dict_file_1.keys()) + +state_dict_file_2 = load_file(args.file_2[0]) +state_dict_2_keys = list(state_dict_file_2.keys()) +keys_in_both = [] + +keys_not_in_state_dict_2 = [] +for key in state_dict_1_keys: + if key not in state_dict_2_keys: + keys_not_in_state_dict_2.append(key) + +keys_not_in_state_dict_1 = [] +for key in state_dict_2_keys: + if key not in state_dict_1_keys: + keys_not_in_state_dict_1.append(key) + +keys_in_both = [] +for key in state_dict_1_keys: + if key in state_dict_2_keys: + keys_in_both.append(key) + +# sort them +keys_not_in_state_dict_2.sort() +keys_not_in_state_dict_1.sort() +keys_in_both.sort() + + +json_data = { + "both": keys_in_both, + "not_in_state_dict_2": keys_not_in_state_dict_2, + "not_in_state_dict_1": keys_not_in_state_dict_1 +} +json_data = json.dumps(json_data, indent=4) + +remaining_diffusers_values = OrderedDict() +for key in keys_not_in_state_dict_1: + remaining_diffusers_values[key] = state_dict_file_2[key] + +# print(remaining_diffusers_values.keys()) + +remaining_ldm_values = OrderedDict() +for key in keys_not_in_state_dict_2: + remaining_ldm_values[key] = state_dict_file_1[key] + +# print(json_data) + +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +json_save_path = os.path.join(project_root, 'config', 'keys.json') +json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') +json_duped_save_path = os.path.join(project_root, 
'config', 'duped.json') +state_dict_1_filename = os.path.basename(args.file_1[0]) +state_dict_2_filename = os.path.basename(args.file_2[0]) +# save key names for each in own file +with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f: + f.write(json.dumps(state_dict_1_keys, indent=4)) + +with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f: + f.write(json.dumps(state_dict_2_keys, indent=4)) + + +with open(json_save_path, 'w') as f: + f.write(json_data) \ No newline at end of file diff --git a/ai-toolkit/testing/generate_lora_mapping.py b/ai-toolkit/testing/generate_lora_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..e632d2a662f6a6498a8d340074ea9c9a27ac431a --- /dev/null +++ b/ai-toolkit/testing/generate_lora_mapping.py @@ -0,0 +1,130 @@ +from collections import OrderedDict + +import torch +from safetensors.torch import load_file +import argparse +import os +import json + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +keymap_path = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', 'stable_diffusion_sdxl.json') + +# load keymap +with open(keymap_path, 'r') as f: + keymap = json.load(f) + +lora_keymap = OrderedDict() + +# convert keymap to lora key naming +for ldm_key, diffusers_key in keymap['ldm_diffusers_keymap'].items(): + if ldm_key.endswith('.bias') or diffusers_key.endswith('.bias'): + # skip it + continue + # sdxl has same te for locon with kohya and ours + if ldm_key.startswith('conditioner'): + #skip it + continue + # ignore vae + if ldm_key.startswith('first_stage_model'): + continue + ldm_key = ldm_key.replace('model.diffusion_model.', 'lora_unet_') + ldm_key = ldm_key.replace('.weight', '') + ldm_key = ldm_key.replace('.', '_') + + diffusers_key = diffusers_key.replace('unet_', 'lora_unet_') + diffusers_key = diffusers_key.replace('.weight', '') + diffusers_key = diffusers_key.replace('.', '_') + + 
def export_state_dict(our_save):
    """Rename keys of `our_save` from the keymap's value side to its key side.

    Every tensor is detached, moved to the CPU and cast to `save_dtype`.
    Keys starting with 'lora_te' are kept unchanged; any other key that equals
    a value of `lora_keymap` is replaced by the matching `lora_keymap` key.

    Returns a new OrderedDict; `our_save` itself is not modified.
    """
    exported = OrderedDict()
    for name, tensor in our_save.items():
        target = name
        if not name.startswith('lora_te'):
            # The scan deliberately does not stop at the first hit, so a
            # renamed key is still compared against the remaining entries.
            for ldm_name, diffusers_name in lora_keymap.items():
                if target == diffusers_name:
                    target = ldm_name
        exported[target] = tensor.detach().to('cpu', dtype=save_dtype)
    return exported


def import_state_dict(loaded_state_dict):
    """Inverse of `export_state_dict`: rename keymap keys back to keymap values.

    Every tensor is detached, moved to the CPU and cast to `save_dtype`.
    Keys starting with 'lora_te' are kept unchanged; any other key that equals
    a key of `lora_keymap` is replaced by the value stored under it.

    Returns a new OrderedDict; `loaded_state_dict` itself is not modified.
    """
    imported = OrderedDict()
    for name, tensor in loaded_state_dict.items():
        target = name
        if not name.startswith('lora_te'):
            # Same full scan as the export direction, with the roles swapped.
            for ldm_name, diffusers_name in lora_keymap.items():
                if target == ldm_name:
                    target = diffusers_name
        imported[target] = tensor.detach().to('cpu', dtype=save_dtype)
    return imported
def flush():
    """Ask PyTorch to empty its CUDA cache, then run a garbage collection pass."""
    torch.cuda.empty_cache()
    gc.collect()


def get_reduced_shape(shape_tuple):
    """Return `shape_tuple` as a tuple with every dimension equal to 1 removed.

    The order of the remaining dimensions is kept, so (320, 320, 1, 1)
    becomes (320, 320). A shape made only of 1s yields the empty tuple.
    """
    return tuple(size for size in shape_tuple if size != 1)
+ diffusers_pipeline = StableDiffusionXLPipeline.from_single_file( + diffusers_file_path, + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", + ).to(device) + # diffusers_pipeline = StableDiffusionXLPipeline.from_single_file( + # file_path, + # torch_dtype=torch.float16, + # use_safetensors=True, + # variant="fp16", + # ).to(device) + + SD_PREFIX_VAE = "vae" + SD_PREFIX_UNET = "unet" + SD_PREFIX_REFINER_UNET = "refiner_unet" + SD_PREFIX_TEXT_ENCODER = "te" + + SD_PREFIX_TEXT_ENCODER1 = "te0" + SD_PREFIX_TEXT_ENCODER2 = "te1" + + diffusers_state_dict = OrderedDict() + for k, v in diffusers_pipeline.vae.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + diffusers_state_dict[new_key] = v + for k, v in diffusers_pipeline.text_encoder_2.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}" + diffusers_state_dict[new_key] = v + for k, v in diffusers_pipeline.unet.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + diffusers_state_dict[new_key] = v + + # add ignore ones as we are only going to focus on unet and copy the rest + # ignore_ldm_begins_with = ["conditioner.", "first_stage_model."] + +diffusers_dict_keys = list(diffusers_state_dict.keys()) + +ldm_state_dict = load_file(file_path) +ldm_dict_keys = list(ldm_state_dict.keys()) + +ldm_diffusers_keymap = OrderedDict() +ldm_diffusers_shape_map = OrderedDict() +ldm_operator_map = OrderedDict() +diffusers_operator_map = OrderedDict() + +total_keys = len(ldm_dict_keys) + +matched_ldm_keys = [] +matched_diffusers_keys = [] + +error_margin = 1e-8 + +tmp_merge_key = "TMP___MERGE" + +te_suffix = '' +proj_pattern_weight = None +proj_pattern_bias = None +text_proj_layer = None +if args.sdxl or args.ssd or args.vega: + te_suffix = '1' + ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks" + proj_pattern_weight = 
r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight" + proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias" + text_proj_layer = "conditioner.embedders.1.model.text_projection" +if args.refiner: + te_suffix = '1' + ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks" + proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight" + proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias" + text_proj_layer = "conditioner.embedders.0.model.text_projection" +if args.sd2: + te_suffix = '' + ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks" + proj_pattern_weight = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight" + proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias" + text_proj_layer = "cond_stage_model.model.text_projection" + +if args.sdxl or args.sd2 or args.ssd or args.refiner or args.vega: + if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys: + # d_model = int(checkpoint[prefix + "text_projection"].shape[0])) + d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0]) + elif "conditioner.embedders.1.model.text_projection.weight" in ldm_dict_keys: + # d_model = int(checkpoint[prefix + "text_projection"].shape[0])) + d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection.weight"].shape[0]) + elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys: + # d_model = int(checkpoint[prefix + "text_projection"].shape[0])) + d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0]) + else: + d_model = 1024 + + # do pre known merging + for ldm_key in ldm_dict_keys: + try: + match = re.match(proj_pattern_weight, ldm_key) + if match: + if ldm_key == 
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": + print("here") + number = int(match.group(1)) + new_val = torch.cat([ + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"], + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"], + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"], + ], dim=0) + # add to matched so we dont check them + matched_diffusers_keys.append( + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight") + matched_diffusers_keys.append( + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight") + matched_diffusers_keys.append( + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight") + # make diffusers convertable_dict + diffusers_state_dict[ + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.weight"] = new_val + + # add operator + ldm_operator_map[ldm_key] = { + "cat": [ + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight", + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight", + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight", + ], + } + + matched_ldm_keys.append(ldm_key) + + # text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + # text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :] + # text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :] + + # add diffusers operators + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight", + f"0:{d_model}, :" + ] + } + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"] = { + "slice": [ + 
f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight", + f"{d_model}:{d_model * 2}, :" + ] + } + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight", + f"{d_model * 2}:, :" + ] + } + + match = re.match(proj_pattern_bias, ldm_key) + if match: + number = int(match.group(1)) + new_val = torch.cat([ + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"], + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"], + diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"], + ], dim=0) + # add to matched so we dont check them + matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias") + matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias") + matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias") + # make diffusers convertable_dict + diffusers_state_dict[ + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.bias"] = new_val + + # add operator + ldm_operator_map[ldm_key] = { + "cat": [ + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias", + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias", + f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias", + ], + } + + matched_ldm_keys.append(ldm_key) + + # add diffusers operators + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias", + f"0:{d_model}, :" + ] + } + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias", + 
f"{d_model}:{d_model * 2}, :" + ] + } + diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"] = { + "slice": [ + f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias", + f"{d_model * 2}:, :" + ] + } + except Exception as e: + print(f"Error on key {ldm_key}") + print(e) + + # update keys + diffusers_dict_keys = list(diffusers_state_dict.keys()) + +pbar = tqdm(ldm_dict_keys, desc='Matching ldm-diffusers keys', total=total_keys) +# run through all weights and check mse between them to find matches +for ldm_key in ldm_dict_keys: + ldm_shape_tuple = ldm_state_dict[ldm_key].shape + ldm_reduced_shape_tuple = get_reduced_shape(ldm_shape_tuple) + for diffusers_key in diffusers_dict_keys: + if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight" and diffusers_key == "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": + print("here") + + diffusers_shape_tuple = diffusers_state_dict[diffusers_key].shape + diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple) + + # That was easy. 
Same key + # if ldm_key == diffusers_key: + # ldm_diffusers_keymap[ldm_key] = diffusers_key + # matched_ldm_keys.append(ldm_key) + # matched_diffusers_keys.append(diffusers_key) + # break + + # if we already have this key mapped, skip it + if diffusers_key in matched_diffusers_keys: + continue + + # if reduced shapes do not match skip it + if ldm_reduced_shape_tuple != diffusers_reduced_shape_tuple: + continue + + ldm_weight = ldm_state_dict[ldm_key] + did_reduce_ldm = False + diffusers_weight = diffusers_state_dict[diffusers_key] + did_reduce_diffusers = False + + # reduce the shapes to match if they are not the same + if ldm_shape_tuple != ldm_reduced_shape_tuple: + ldm_weight = ldm_weight.view(ldm_reduced_shape_tuple) + did_reduce_ldm = True + + if diffusers_shape_tuple != diffusers_reduced_shape_tuple: + diffusers_weight = diffusers_weight.view(diffusers_reduced_shape_tuple) + did_reduce_diffusers = True + + # check to see if they match within a margin of error + mse = torch.nn.functional.mse_loss(ldm_weight.float(), diffusers_weight.float()) + if mse < error_margin: + ldm_diffusers_keymap[ldm_key] = diffusers_key + matched_ldm_keys.append(ldm_key) + matched_diffusers_keys.append(diffusers_key) + + if did_reduce_ldm or did_reduce_diffusers: + ldm_diffusers_shape_map[ldm_key] = (ldm_shape_tuple, diffusers_shape_tuple) + if did_reduce_ldm: + del ldm_weight + if did_reduce_diffusers: + del diffusers_weight + flush() + + break + + pbar.update(1) + +pbar.close() + +name = args.name +if args.sdxl: + name += '_sdxl' +elif args.ssd: + name += '_ssd' +elif args.vega: + name += '_vega' +elif args.refiner: + name += '_refiner' +elif args.sd2: + name += '_sd2' +else: + name += '_sd1' + +# if len(matched_ldm_keys) != len(matched_diffusers_keys): +unmatched_ldm_keys = [x for x in ldm_dict_keys if x not in matched_ldm_keys] +unmatched_diffusers_keys = [x for x in diffusers_dict_keys if x not in matched_diffusers_keys] +# has unmatched keys + +has_unmatched_keys = 
def get_slices_from_string(s: str) -> tuple:
    """Parse a comma-separated slice description into a tuple of `slice` objects.

    Each component uses Python's subscript notation, `start:stop[:step]`, with
    any part optional, e.g. "0:1024, :" -> (slice(0, 1024), slice(None)).
    This is the format written into the operator maps elsewhere in this script.

    The previous implementation evaluated `slice(<component>)` with `eval`.
    That is a syntax error for every component containing ':' and executes
    arbitrary text, so the components are parsed explicitly here.

    A component without any ':' (e.g. "5") keeps its old meaning of
    `slice(5)`, i.e. "everything up to index 5".

    Args:
        s: Slice description such as "0:1024, :" or "::2".

    Returns:
        One `slice` per comma-separated component, in order.

    Raises:
        ValueError: If a component is empty, has more than three ':'-separated
            parts, or contains a part that is not an integer.
    """

    def _bound(text):
        # An omitted bound ("" between colons) means "no limit", as in a[:n].
        text = text.strip()
        return int(text) if text else None

    slices = []
    for component in s.split(','):
        parts = component.strip().split(':')
        if len(parts) > 3:
            raise ValueError(f"Too many ':' in slice component: {component!r}")
        if len(parts) == 1:
            if not parts[0]:
                raise ValueError(f"Empty slice component in: {s!r}")
            # Bare integer: same result as the old slice(N) evaluation.
            slices.append(slice(int(parts[0])))
        else:
            slices.append(slice(*[_bound(part) for part in parts]))
    return tuple(slices)
slicing + if tmp_merge_key in diffusers_key or tmp_merge_key in ldm_key: + to_remove.append(ldm_key) + +for key in to_remove: + del ldm_diffusers_keymap[key] + +to_remove = [] +# remove identical shape mappings. Not sure why they exist but they do +for ldm_key, shape_list in ldm_diffusers_shape_map.items(): + # remove identical shape mappings. Not sure why they exist but they do + # convert to json string to make it easier to compare + ldm_shape = json.dumps(shape_list[0]) + diffusers_shape = json.dumps(shape_list[1]) + if ldm_shape == diffusers_shape: + to_remove.append(ldm_key) + +for key in to_remove: + del ldm_diffusers_shape_map[key] + +dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json') +save_obj = OrderedDict() +save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap +save_obj["ldm_diffusers_shape_map"] = ldm_diffusers_shape_map +save_obj["ldm_diffusers_operator_map"] = ldm_operator_map +save_obj["diffusers_ldm_operator_map"] = diffusers_operator_map +with open(dest_path, 'w') as f: + f.write(json.dumps(save_obj, indent=4)) + +print(f'Saved keymap to {dest_path}') diff --git a/ai-toolkit/testing/merge_in_text_encoder_adapter.py b/ai-toolkit/testing/merge_in_text_encoder_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a2983c82469a2c6e56874df8b184d30f2d23fc --- /dev/null +++ b/ai-toolkit/testing/merge_in_text_encoder_adapter.py @@ -0,0 +1,180 @@ +import os + +import torch +from transformers import T5EncoderModel, T5Tokenizer +from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel, PixArtTransformer2DModel +from safetensors.torch import load_file, save_file +from collections import OrderedDict +import json + +# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000" +# te_path = "google/flan-t5-xl" +# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors" +# output_path = 
"/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw" +model_path = "/home/jaret/Dev/models/hf/objective-reality-16ch" +te_path = "google/flan-t5-xl" +te_aug_path = "/mnt/Train2/out/ip_adapter/t5xl-sd15-16ch_v1/t5xl-sd15-16ch_v1_000115000.safetensors" +output_path = "/home/jaret/Dev/models/hf/t5xl-sd15-16ch_sd15_v1" + + +print("Loading te adapter") +te_aug_sd = load_file(te_aug_path) + +print("Loading model") +is_diffusers = (not os.path.exists(model_path)) or os.path.isdir(model_path) + +# if "pixart" in model_path.lower(): +is_pixart = "pixart" in model_path.lower() + +pipeline_class = StableDiffusionPipeline + +# transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16) + +if is_pixart: + pipeline_class = PixArtSigmaPipeline + +if is_diffusers: + sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16) +else: + sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16) + +print("Loading Text Encoder") +# Load the text encoder +te = T5EncoderModel.from_pretrained(te_path, torch_dtype=torch.float16) + +# patch it +sd.text_encoder = te +sd.tokenizer = T5Tokenizer.from_pretrained(te_path) + +if is_pixart: + unet = sd.transformer + unet_sd = sd.transformer.state_dict() +else: + unet = sd.unet + unet_sd = sd.unet.state_dict() + + +if is_pixart: + weight_idx = 0 +else: + weight_idx = 1 + +new_cross_attn_dim = None + +# count the num of params in state dict +start_params = sum([v.numel() for v in unet_sd.values()]) + +print("Building") +attn_processor_keys = [] +if is_pixart: + transformer: Transformer2DModel = unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") +else: + attn_processor_keys = list(unet.attn_processors.keys()) + +for name in attn_processor_keys: + cross_attention_dim = 
None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith( + "attn1") else \ + unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config['block_out_channels'][block_id] + elif name.startswith("transformer"): + hidden_size = unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + pass + else: + layer_name = name.split(".processor")[0] + to_k_adapter = unet_sd[layer_name + ".to_k.weight"] + to_v_adapter = unet_sd[layer_name + ".to_v.weight"] + + te_aug_name = None + while True: + if is_pixart: + te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter" + else: + te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter" + if f"{te_aug_name}.weight" in te_aug_sd: + # increment so we dont redo it next time + weight_idx += 1 + break + else: + weight_idx += 1 + + if weight_idx > 1000: + raise ValueError("Could not find the next weight") + + orig_weight_shape_k = list(unet_sd[layer_name + ".to_k.weight"].shape) + new_weight_shape_k = list(te_aug_sd[te_aug_name + ".weight"].shape) + orig_weight_shape_v = list(unet_sd[layer_name + ".to_v.weight"].shape) + new_weight_shape_v = list(te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"].shape) + + unet_sd[layer_name + ".to_k.weight"] = te_aug_sd[te_aug_name + ".weight"] + unet_sd[layer_name + ".to_v.weight"] = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"] + + if new_cross_attn_dim is None: + new_cross_attn_dim = unet_sd[layer_name + ".to_k.weight"].shape[1] + + + +if is_pixart: + # copy the 
caption_projection weight + del unet_sd['caption_projection.linear_1.bias'] + del unet_sd['caption_projection.linear_1.weight'] + del unet_sd['caption_projection.linear_2.bias'] + del unet_sd['caption_projection.linear_2.weight'] + +print("Saving unmodified model") +sd = sd.to("cpu", torch.float16) +sd.save_pretrained( + output_path, + safe_serialization=True, +) + +# overwrite the unet +if is_pixart: + unet_folder = os.path.join(output_path, "transformer") +else: + unet_folder = os.path.join(output_path, "unet") + +# move state_dict to cpu +unet_sd = {k: v.clone().cpu().to(torch.float16) for k, v in unet_sd.items()} + +meta = OrderedDict() +meta["format"] = "pt" + +print("Patching") + +save_file(unet_sd, os.path.join(unet_folder, "diffusion_pytorch_model.safetensors"), meta) + +# load the json file +with open(os.path.join(unet_folder, "config.json"), 'r') as f: + config = json.load(f) + +config['cross_attention_dim'] = new_cross_attn_dim + +if is_pixart: + config['caption_channels'] = None + +# save it +with open(os.path.join(unet_folder, "config.json"), 'w') as f: + json.dump(config, f, indent=2) + +print("Done") + +new_params = sum([v.numel() for v in unet_sd.values()]) + +# print new and old params with , formatted +print(f"Old params: {start_params:,}") +print(f"New params: {new_params:,}") diff --git a/ai-toolkit/testing/shrink_pixart.py b/ai-toolkit/testing/shrink_pixart.py new file mode 100644 index 0000000000000000000000000000000000000000..ad27b1a0ea38612a2a4202261ca88a7875281db1 --- /dev/null +++ b/ai-toolkit/testing/shrink_pixart.py @@ -0,0 +1,62 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors" +output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors" + +state_dict = load_file(model_path) + +meta = OrderedDict() 
+meta["format"] = "pt" + +new_state_dict = {} + +# Move non-blocks over +for key, value in state_dict.items(): + if not key.startswith("transformer_blocks."): + new_state_dict[key] = value + +block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight', + 'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight', + 'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight', + 'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight', + 'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight', + 'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight', + 'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight', + 'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight', + 'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight', + 'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight', + 'transformer_blocks.{idx}.scale_shift_table'] + +# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27 + +current_idx = 0 +for i in range(28): + if i not in [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]: + # todo merge in with previous block + for name in block_names: + try: + new_state_dict_key = name.format(idx=current_idx - 1) + old_state_dict_key = name.format(idx=i) + new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5) + except KeyError: + raise KeyError(f"KeyError: {name.format(idx=current_idx)}") + else: + for name in block_names: + new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)] + current_idx += 1 + + +# make sure they are all fp16 and on cpu +for key, value in new_state_dict.items(): + new_state_dict[key] 
= value.to(torch.float16).cpu() + +# save the new state dict +save_file(new_state_dict, output_path, metadata=meta) + +new_param_count = sum([v.numel() for v in new_state_dict.values()]) +old_param_count = sum([v.numel() for v in state_dict.values()]) + +print(f"Old param count: {old_param_count:,}") +print(f"New param count: {new_param_count:,}") \ No newline at end of file diff --git a/ai-toolkit/testing/shrink_pixart2.py b/ai-toolkit/testing/shrink_pixart2.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c30cf87f38610ac23b31afdc94311fba8e3a41 --- /dev/null +++ b/ai-toolkit/testing/shrink_pixart2.py @@ -0,0 +1,81 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors" +output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors" + +state_dict = load_file(model_path) + +meta = OrderedDict() +meta["format"] = "pt" + +new_state_dict = {} + +# Move non-blocks over +for key, value in state_dict.items(): + if not key.startswith("transformer_blocks."): + new_state_dict[key] = value + +block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight', + 'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight', + 'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight', + 'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight', + 'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight', + 'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight', + 'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight', + 'transformer_blocks.{idx}.attn2.to_v.bias', 
'transformer_blocks.{idx}.attn2.to_v.weight', + 'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight', + 'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight', + 'transformer_blocks.{idx}.scale_shift_table'] + +# Blocks to keep +# keep_blocks = [0, 1, 2, 6, 10, 14, 18, 22, 26, 27] +keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27] + + +def weighted_merge(kept_block, removed_block, weight): + return kept_block * (1 - weight) + removed_block * weight + + +# First, copy all kept blocks to new_state_dict +for i, old_idx in enumerate(keep_blocks): + for name in block_names: + old_key = name.format(idx=old_idx) + new_key = name.format(idx=i) + new_state_dict[new_key] = state_dict[old_key].clone() + +# Then, merge information from removed blocks +for i in range(28): + if i not in keep_blocks: + # Find the nearest kept blocks + prev_kept = max([b for b in keep_blocks if b < i]) + next_kept = min([b for b in keep_blocks if b > i]) + + # Calculate the weight based on position + weight = (i - prev_kept) / (next_kept - prev_kept) + + for name in block_names: + removed_key = name.format(idx=i) + prev_new_key = name.format(idx=keep_blocks.index(prev_kept)) + next_new_key = name.format(idx=keep_blocks.index(next_kept)) + + # Weighted merge for previous kept block + new_state_dict[prev_new_key] = weighted_merge(new_state_dict[prev_new_key], state_dict[removed_key], weight) + + # Weighted merge for next kept block + new_state_dict[next_new_key] = weighted_merge(new_state_dict[next_new_key], state_dict[removed_key], + 1 - weight) + +# Convert to fp16 and move to CPU +for key, value in new_state_dict.items(): + new_state_dict[key] = value.to(torch.float16).cpu() + +# Save the new state dict +save_file(new_state_dict, output_path, metadata=meta) + +new_param_count = sum([v.numel() for v in new_state_dict.values()]) +old_param_count = sum([v.numel() for v in state_dict.values()]) + 
+print(f"Old param count: {old_param_count:,}") +print(f"New param count: {new_param_count:,}") \ No newline at end of file diff --git a/ai-toolkit/testing/shrink_pixart_sm.py b/ai-toolkit/testing/shrink_pixart_sm.py new file mode 100644 index 0000000000000000000000000000000000000000..8cea07bf154fdd653f1a928f3c553dc56580a828 --- /dev/null +++ b/ai-toolkit/testing/shrink_pixart_sm.py @@ -0,0 +1,84 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +meta = OrderedDict() +meta['format'] = "pt" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def reduce_weight(weight, target_size): + weight = weight.to(device, torch.float32) + original_shape = weight.shape + flattened = weight.view(-1, original_shape[-1]) + + if flattened.shape[1] <= target_size: + return weight + + U, S, V = torch.svd(flattened) + reduced = torch.mm(U[:, :target_size], torch.diag(S[:target_size])) + + if reduced.shape[1] < target_size: + padding = torch.zeros(reduced.shape[0], target_size - reduced.shape[1], device=device) + reduced = torch.cat((reduced, padding), dim=1) + + return reduced.view(original_shape[:-1] + (target_size,)) + + +def reduce_bias(bias, target_size): + bias = bias.to(device, torch.float32) + original_size = bias.shape[0] + + if original_size <= target_size: + return torch.nn.functional.pad(bias, (0, target_size - original_size)) + else: + return bias.view(-1, original_size // target_size).mean(dim=1)[:target_size] + + +# Load your original state dict +state_dict = load_file( + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors") + +# Create a new state dict for the reduced model +new_state_dict = {} + +source_hidden_size = 1152 +target_hidden_size = 1024 + +for key, value in state_dict.items(): + value = value.to(device, torch.float32) + if 'weight' in key or 'scale_shift_table' in key: + if value.shape[0] == source_hidden_size: + 
value = value[:target_hidden_size] + elif value.shape[0] == source_hidden_size * 4: + value = value[:target_hidden_size * 4] + elif value.shape[0] == source_hidden_size * 6: + value = value[:target_hidden_size * 6] + + if len(value.shape) > 1 and value.shape[ + 1] == source_hidden_size and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key: + value = value[:, :target_hidden_size] + elif len(value.shape) > 1 and value.shape[1] == source_hidden_size * 4: + value = value[:, :target_hidden_size * 4] + + elif 'bias' in key: + if value.shape[0] == source_hidden_size: + value = value[:target_hidden_size] + elif value.shape[0] == source_hidden_size * 4: + value = value[:target_hidden_size * 4] + elif value.shape[0] == source_hidden_size * 6: + value = value[:target_hidden_size * 6] + + new_state_dict[key] = value + +# Move all to CPU and convert to float16 +for key, value in new_state_dict.items(): + new_state_dict[key] = value.cpu().to(torch.float16) + +# Save the new state dict +save_file(new_state_dict, + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors", + metadata=meta) + +print("Done!") diff --git a/ai-toolkit/testing/shrink_pixart_sm2.py b/ai-toolkit/testing/shrink_pixart_sm2.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3304dfc72e50e445fce27ae793e7544009aa1a --- /dev/null +++ b/ai-toolkit/testing/shrink_pixart_sm2.py @@ -0,0 +1,110 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +meta = OrderedDict() +meta['format'] = "pt" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def reduce_weight(weight, target_size): + weight = weight.to(device, torch.float32) + original_shape = weight.shape + + if len(original_shape) == 1: + # For 1D tensors, simply truncate + return weight[:target_size] + + if original_shape[0] <= target_size: + return weight + + # Reshape the tensor to 2D + 
flattened = weight.reshape(original_shape[0], -1) + + # Perform SVD + U, S, V = torch.svd(flattened) + + # Reduce the dimensions + reduced = torch.mm(U[:target_size, :], torch.diag(S)).mm(V.t()) + + # Reshape back to the original shape with reduced first dimension + new_shape = (target_size,) + original_shape[1:] + return reduced.reshape(new_shape) + + +def reduce_bias(bias, target_size): + bias = bias.to(device, torch.float32) + return bias[:target_size] + + +# Load your original state dict +state_dict = load_file( + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors") + +# Create a new state dict for the reduced model +new_state_dict = {} + +for key, value in state_dict.items(): + value = value.to(device, torch.float32) + + if 'weight' in key or 'scale_shift_table' in key: + if value.shape[0] == 1152: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3]) # reshape to (1152, -1) + # reshape to (1152, -1) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 512) + value = value.view(output_shape) + else: + # value = reduce_weight(value.t(), 576).t().contiguous() + value = reduce_weight(value, 512) + pass + elif value.shape[0] == 4608: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3]) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 2048) + value = value.view(output_shape) + else: + value = reduce_weight(value, 2048) + elif value.shape[0] == 6912: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3]) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 3072) + value = value.view(output_shape) + else: + value = reduce_weight(value, 3072) + + if len(value.shape) > 1 and value.shape[ + 1] == 1152 and 'attn2.to_k.weight' not in key and 
'attn2.to_v.weight' not in key: + value = reduce_weight(value.t(), 512).t().contiguous() # Transpose before and after reduction + pass + elif len(value.shape) > 1 and value.shape[1] == 4608: + value = reduce_weight(value.t(), 2048).t().contiguous() # Transpose before and after reduction + pass + + elif 'bias' in key: + if value.shape[0] == 1152: + value = reduce_bias(value, 512) + elif value.shape[0] == 4608: + value = reduce_bias(value, 2048) + elif value.shape[0] == 6912: + value = reduce_bias(value, 3072) + + new_state_dict[key] = value + +# Move all to CPU and convert to float16 +for key, value in new_state_dict.items(): + new_state_dict[key] = value.cpu().to(torch.float16) + +# Save the new state dict +save_file(new_state_dict, + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors", + metadata=meta) + +print("Done!") \ No newline at end of file diff --git a/ai-toolkit/testing/shrink_pixart_sm3.py b/ai-toolkit/testing/shrink_pixart_sm3.py new file mode 100644 index 0000000000000000000000000000000000000000..b8756aec45b4a5cb59315ab11a5bed320d74f7ba --- /dev/null +++ b/ai-toolkit/testing/shrink_pixart_sm3.py @@ -0,0 +1,100 @@ +import torch +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +meta = OrderedDict() +meta['format'] = "pt" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def reduce_weight(weight, target_size): + weight = weight.to(device, torch.float32) + # resize so target_size is the first dimension + tmp_weight = weight.view(1, 1, weight.shape[0], weight.shape[1]) + + # use interpolate to resize the tensor + new_weight = torch.nn.functional.interpolate(tmp_weight, size=(target_size, weight.shape[1]), mode='bicubic', align_corners=True) + + # reshape back to original shape + return new_weight.view(target_size, weight.shape[1]) + + +def reduce_bias(bias, target_size): + bias = bias.view(1, 1, bias.shape[0], 1) + + new_bias = 
torch.nn.functional.interpolate(bias, size=(target_size, 1), mode='bicubic', align_corners=True) + + return new_bias.view(target_size) + + +# Load your original state dict +state_dict = load_file( + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors") + +# Create a new state dict for the reduced model +new_state_dict = {} + +for key, value in state_dict.items(): + value = value.to(device, torch.float32) + + if 'weight' in key or 'scale_shift_table' in key: + if value.shape[0] == 1152: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3]) # reshape to (1152, -1) + # reshape to (1152, -1) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 512) + value = value.view(output_shape) + else: + # value = reduce_weight(value.t(), 576).t().contiguous() + value = reduce_weight(value, 512) + pass + elif value.shape[0] == 4608: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3]) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 2048) + value = value.view(output_shape) + else: + value = reduce_weight(value, 2048) + elif value.shape[0] == 6912: + if len(value.shape) == 4: + orig_shape = value.shape + output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3]) + value = value.view(value.shape[0], -1) + value = reduce_weight(value, 3072) + value = value.view(output_shape) + else: + value = reduce_weight(value, 3072) + + if len(value.shape) > 1 and value.shape[ + 1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key: + value = reduce_weight(value.t(), 512).t().contiguous() # Transpose before and after reduction + pass + elif len(value.shape) > 1 and value.shape[1] == 4608: + value = reduce_weight(value.t(), 2048).t().contiguous() # Transpose before and after reduction + pass + + elif 'bias' in key: + if 
value.shape[0] == 1152: + value = reduce_bias(value, 512) + elif value.shape[0] == 4608: + value = reduce_bias(value, 2048) + elif value.shape[0] == 6912: + value = reduce_bias(value, 3072) + + new_state_dict[key] = value + +# Move all to CPU and convert to float16 +for key, value in new_state_dict.items(): + new_state_dict[key] = value.cpu().to(torch.float16) + +# Save the new state dict +save_file(new_state_dict, + "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors", + metadata=meta) + +print("Done!") \ No newline at end of file diff --git a/ai-toolkit/testing/test_bucket_dataloader.py b/ai-toolkit/testing/test_bucket_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..b453742227197a5ce49c6a8a98540d1a79fd0846 --- /dev/null +++ b/ai-toolkit/testing/test_bucket_dataloader.py @@ -0,0 +1,148 @@ +import time + +import numpy as np +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +import sys +import os +import cv2 +import random +from transformers import CLIPImageProcessor + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torchvision.transforms.functional +from toolkit.image_utils import save_tensors, show_img, show_tensors + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \ + trigger_dataloader_setup_epoch +from toolkit.config_modules import DatasetConfig +import argparse +from tqdm import tqdm + +parser = argparse.ArgumentParser() +parser.add_argument('dataset_folder', type=str, default='input') +parser.add_argument('--epochs', type=int, default=1) +parser.add_argument('--num_frames', type=int, default=1) +parser.add_argument('--output_path', type=str, default=None) + + +args = parser.parse_args() + +if args.output_path is not None: + args.output_path = os.path.abspath(args.output_path) + 
os.makedirs(args.output_path, exist_ok=True) + +dataset_folder = args.dataset_folder +resolution = 512 +bucket_tolerance = 64 +batch_size = 1 + +clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16") + +class FakeAdapter: + def __init__(self): + self.clip_image_processor = clip_processor + + +## make fake sd +class FakeSD: + def __init__(self): + self.adapter = FakeAdapter() + self.use_raw_control_images = False + + def encode_control_in_text_embeddings(self, *args, **kwargs): + return None + + def get_bucket_divisibility(self): + return 32 + +dataset_config = DatasetConfig( + dataset_path=dataset_folder, + # clip_image_path=dataset_folder, + # square_crop=True, + resolution=resolution, + # caption_ext='json', + default_caption='default', + # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/', + buckets=True, + bucket_tolerance=bucket_tolerance, + shrink_video_to_frames=True, + num_frames=args.num_frames, + # poi='person', + # shuffle_augmentations=True, + # augmentations=[ + # { + # 'method': 'Posterize', + # 'num_bits': [(0, 4), (0, 4), (0, 4)], + # 'p': 1.0 + # }, + # + # ] +) + +dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD()) + + +# run through an epoch ang check sizes +dataloader_iterator = iter(dataloader) +idx = 0 +for epoch in range(args.epochs): + for batch in tqdm(dataloader): + batch: 'DataLoaderBatchDTO' + img_batch = batch.tensor + frames = 1 + if len(img_batch.shape) == 5: + frames = img_batch.shape[1] + batch_size, frames, channels, height, width = img_batch.shape + else: + batch_size, channels, height, width = img_batch.shape + + # img_batch = color_block_imgs(img_batch, neg1_1=True) + + # chunks = torch.chunk(img_batch, batch_size, dim=0) + # # put them so they are size by side + # big_img = torch.cat(chunks, dim=3) + # big_img = big_img.squeeze(0) + # + # control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0) + # 
big_control_img = torch.cat(control_chunks, dim=3) + # big_control_img = big_control_img.squeeze(0) * 2 - 1 + # + # + # # resize control image + # big_control_img = torchvision.transforms.Resize((width, height))(big_control_img) + # + # big_img = torch.cat([big_img, big_control_img], dim=2) + # + # min_val = big_img.min() + # max_val = big_img.max() + # + # big_img = (big_img / 2 + 0.5).clamp(0, 1) + + big_img = img_batch + # big_img = big_img.clamp(-1, 1) + if args.output_path is not None: + if len(img_batch.shape) == 5: + # video + save_tensors(big_img, os.path.join(args.output_path, f'{idx}.webp'), fps=16) + else: + save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png')) + else: + show_tensors(big_img) + + # convert to image + # img = transforms.ToPILImage()(big_img) + # + # show_img(img) + + time.sleep(0.2) + idx += 1 + # if not last epoch + if epoch < args.epochs - 1: + trigger_dataloader_setup_epoch(dataloader) + +cv2.destroyAllWindows() + +print('done') diff --git a/ai-toolkit/testing/test_ltx_dataloader.py b/ai-toolkit/testing/test_ltx_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..d27426d56959d8f100cd2ce670eeedf1641771be --- /dev/null +++ b/ai-toolkit/testing/test_ltx_dataloader.py @@ -0,0 +1,234 @@ +import time + +from torch.utils.data import DataLoader +import sys +import os +import argparse +from tqdm import tqdm +import torch +from torchvision.io import write_video +import subprocess + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch +from toolkit.config_modules import DatasetConfig + +parser = argparse.ArgumentParser() +# parser.add_argument('dataset_folder', type=str, default='input') +parser.add_argument('dataset_folder', type=str) +parser.add_argument('--epochs', type=int, default=1) 
+parser.add_argument('--num_frames', type=int, default=121) +parser.add_argument('--output_path', type=str, default='output/dataset_test') + + +args = parser.parse_args() + +if args.output_path is None: + raise ValueError('output_path is required for this test script') + +if args.output_path is not None: + args.output_path = os.path.abspath(args.output_path) + os.makedirs(args.output_path, exist_ok=True) + +dataset_folder = args.dataset_folder +resolution = 512 +bucket_tolerance = 64 +batch_size = 1 +frame_rate = 24 + + +## make fake sd +class FakeSD: + def __init__(self): + self.use_raw_control_images = False + + def encode_control_in_text_embeddings(self, *args, **kwargs): + return None + + def get_bucket_divisibility(self): + return 32 + +dataset_config = DatasetConfig( + dataset_path=dataset_folder, + resolution=resolution, + default_caption='default', + buckets=True, + bucket_tolerance=bucket_tolerance, + shrink_video_to_frames=True, + num_frames=args.num_frames, + do_i2v=True, + fps=frame_rate, + do_audio=True, + debug=True, + audio_preserve_pitch=False, + audio_normalize=True + +) + +dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD()) + + +def _tensor_to_uint8_video(frames_fchw: torch.Tensor) -> torch.Tensor: + """ + frames_fchw: [F, C, H, W] float/uint8 + returns: [F, H, W, C] uint8 on CPU + """ + x = frames_fchw.detach() + + if x.dtype != torch.uint8: + x = x.to(torch.float32) + + # Heuristic: if negatives exist, assume [-1,1] normalization; else assume [0,1] + if torch.isfinite(x).all(): + if x.min().item() < 0.0: + x = x * 0.5 + 0.5 + x = x.clamp(0.0, 1.0) + x = (x * 255.0).round().to(torch.uint8) + else: + x = x.to(torch.uint8) + + # [F,C,H,W] -> [F,H,W,C] + x = x.permute(0, 2, 3, 1).contiguous().cpu() + return x + + +def _mux_with_ffmpeg(video_in: str, wav_in: str, mp4_out: str): + # Copy video stream, encode audio to AAC, align to shortest + subprocess.run( + [ + "ffmpeg", + "-y", + 
"-hide_banner", + "-loglevel", + "error", + "-i", + video_in, + "-i", + wav_in, + "-c:v", + "copy", + "-c:a", + "aac", + "-shortest", + mp4_out, + ], + check=True, + ) + + +# run through an epoch ang check sizes +dataloader_iterator = iter(dataloader) +idx = 0 +for epoch in range(args.epochs): + for batch in tqdm(dataloader): + batch: 'DataLoaderBatchDTO' + img_batch = batch.tensor + frames = 1 + if len(img_batch.shape) == 5: + frames = img_batch.shape[1] + batch_size, frames, channels, height, width = img_batch.shape + else: + batch_size, channels, height, width = img_batch.shape + + # load audio + audio_tensor = batch.audio_tensor # all file items contatinated on the batch dimension + audio_data = batch.audio_data # list of raw audio data per item in the batch + + # llm save the videos here with audio and video as mp4 + fps = getattr(dataset_config, "fps", None) + if fps is None or fps <= 0: + fps = 1.0 + + # Ensure we can iterate items even if batch_size > 1 + for b in range(batch_size): + # Get per-item frames as [F,C,H,W] + if len(img_batch.shape) == 5: + frames_fchw = img_batch[b] + else: + # single image: [C,H,W] -> [1,C,H,W] + frames_fchw = img_batch[b].unsqueeze(0) + + video_uint8 = _tensor_to_uint8_video(frames_fchw) + out_mp4 = os.path.join(args.output_path, f"{idx:06d}_{b:02d}.mp4") + + # Pick audio for this item (prefer audio_data list; fallback to audio_tensor) + item_audio = None + item_sr = None + + if isinstance(audio_data, (list, tuple)) and len(audio_data) > b: + ad = audio_data[b] + if isinstance(ad, dict) and ("waveform" in ad) and ("sample_rate" in ad) and ad["waveform"] is not None: + item_audio = ad["waveform"] + item_sr = int(ad["sample_rate"]) + elif audio_tensor is not None and torch.is_tensor(audio_tensor): + # audio_tensor expected [B, C, L] (or [C,L] if batch collate differs) + if audio_tensor.dim() == 3 and audio_tensor.shape[0] > b: + item_audio = audio_tensor[b] + elif audio_tensor.dim() == 2 and b == 0: + item_audio = audio_tensor 
+ if item_audio is not None: + # best-effort sample rate from audio_data if present but not per-item dict + if isinstance(audio_data, dict) and "sample_rate" in audio_data: + try: + item_sr = int(audio_data["sample_rate"]) + except Exception: + item_sr = None + + # Write mp4 (with audio if available) using ffmpeg muxing (torchvision audio muxing is unreliable) + tmp_video = out_mp4 + ".tmp_video.mp4" + tmp_wav = out_mp4 + ".tmp_audio.wav" + try: + # Always write video-only first + write_video(tmp_video, video_uint8, fps=float(fps), video_codec="libx264") + + if item_audio is not None and item_sr is not None and item_audio.numel() > 0: + import torchaudio + + wav = item_audio.detach() + # torchaudio.save expects [channels, samples] + if wav.dim() == 1: + wav = wav.unsqueeze(0) + torchaudio.save(tmp_wav, wav.cpu().to(torch.float32), int(item_sr)) + + # Mux to final mp4 + _mux_with_ffmpeg(tmp_video, tmp_wav, out_mp4) + else: + # No audio: just move video into place + os.replace(tmp_video, out_mp4) + + except Exception as e: + # Best-effort fallback: leave a playable video-only file + try: + if os.path.exists(tmp_video): + os.replace(tmp_video, out_mp4) + else: + write_video(out_mp4, video_uint8, fps=float(fps), video_codec="libx264") + except Exception: + raise + + if hasattr(dataset_config, 'debug') and dataset_config.debug: + print(f"Warning: failed to mux audio into mp4 for {out_mp4}: {e}") + + finally: + # Cleanup temps (don't leave separate wavs lying around) + try: + if os.path.exists(tmp_video): + os.remove(tmp_video) + except Exception: + pass + try: + if os.path.exists(tmp_wav): + os.remove(tmp_wav) + except Exception: + pass + + time.sleep(0.2) + + idx += 1 + # if not last epoch + if epoch < args.epochs - 1: + trigger_dataloader_setup_epoch(dataloader) + +print('done') diff --git a/ai-toolkit/testing/test_model_load_save.py b/ai-toolkit/testing/test_model_load_save.py new file mode 100644 index 
0000000000000000000000000000000000000000..87bdfb3ef8246268f0660db6bf24822c74506c45 --- /dev/null +++ b/ai-toolkit/testing/test_model_load_save.py @@ -0,0 +1,172 @@ +import argparse +import os +# add project root to sys path +import sys + +from tqdm import tqdm + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch +from diffusers.loaders import LoraLoaderMixin +from safetensors.torch import load_file +from collections import OrderedDict +import json + +from toolkit.config_modules import ModelConfig +from toolkit.paths import KEYMAPS_ROOT +from toolkit.saving import convert_state_dict_to_ldm_with_mapping, get_ldm_state_dict_from_diffusers +from toolkit.stable_diffusion_model import StableDiffusion + +# this was just used to match the vae keys to the diffusers keys +# you probably wont need this. Unless they change them.... again... again +# on second thought, you probably will + +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +device = torch.device('cpu') +dtype = torch.float32 + +parser = argparse.ArgumentParser() + +# require at lease one config file +parser.add_argument( + 'file_1', + nargs='+', + type=str, + help='Path an LDM model' +) + +parser.add_argument( + '--is_xl', + action='store_true', + help='Is the model an XL model' +) + +parser.add_argument( + '--is_v2', + action='store_true', + help='Is the model a v2 model' +) + +args = parser.parse_args() + +find_matches = False + +print("Loading model") +state_dict_file_1 = load_file(args.file_1[0]) +state_dict_1_keys = list(state_dict_file_1.keys()) + +print("Loading model into diffusers format") +model_config = ModelConfig( + name_or_path=args.file_1[0], + is_xl=args.is_xl +) +sd = StableDiffusion( + model_config=model_config, + device=device, +) +sd.load_model() + +# load our base +base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors') +mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json') + 
+print("Converting model back to LDM") +version_string = '1' +if args.is_v2: + version_string = '2' +if args.is_xl: + version_string = 'sdxl' +# convert the state dict +state_dict_file_2 = get_ldm_state_dict_from_diffusers( + sd.state_dict(), + version_string, + device='cpu', + dtype=dtype +) + +# state_dict_file_2 = load_file(args.file_2[0]) + +state_dict_2_keys = list(state_dict_file_2.keys()) +keys_in_both = [] + +keys_not_in_state_dict_2 = [] +for key in state_dict_1_keys: + if key not in state_dict_2_keys: + keys_not_in_state_dict_2.append(key) + +keys_not_in_state_dict_1 = [] +for key in state_dict_2_keys: + if key not in state_dict_1_keys: + keys_not_in_state_dict_1.append(key) + +keys_in_both = [] +for key in state_dict_1_keys: + if key in state_dict_2_keys: + keys_in_both.append(key) + +# sort them +keys_not_in_state_dict_2.sort() +keys_not_in_state_dict_1.sort() +keys_in_both.sort() + +if len(keys_not_in_state_dict_2) == 0 and len(keys_not_in_state_dict_1) == 0: + print("All keys match!") + print("Checking values...") + mismatch_keys = [] + loss = torch.nn.MSELoss() + tolerance = 1e-6 + for key in tqdm(keys_in_both): + if loss(state_dict_file_1[key], state_dict_file_2[key]) > tolerance: + print(f"Values for key {key} don't match!") + print(f"Loss: {loss(state_dict_file_1[key], state_dict_file_2[key])}") + mismatch_keys.append(key) + + if len(mismatch_keys) == 0: + print("All values match!") + else: + print("Some valued font match!") + print(mismatch_keys) + mismatched_path = os.path.join(project_root, 'config', 'mismatch.json') + with open(mismatched_path, 'w') as f: + f.write(json.dumps(mismatch_keys, indent=4)) + exit(0) + +else: + print("Keys don't match!, generating info...") + +json_data = { + "both": keys_in_both, + "not_in_state_dict_2": keys_not_in_state_dict_2, + "not_in_state_dict_1": keys_not_in_state_dict_1 +} +json_data = json.dumps(json_data, indent=4) + +remaining_diffusers_values = OrderedDict() +for key in keys_not_in_state_dict_1: + 
remaining_diffusers_values[key] = state_dict_file_2[key] + +# print(remaining_diffusers_values.keys()) + +remaining_ldm_values = OrderedDict() +for key in keys_not_in_state_dict_2: + remaining_ldm_values[key] = state_dict_file_1[key] + +# print(json_data) + + +json_save_path = os.path.join(project_root, 'config', 'keys.json') +json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') +json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') +state_dict_1_filename = os.path.basename(args.file_1[0]) +# state_dict_2_filename = os.path.basename(args.file_2[0]) +# save key names for each in own file +with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f: + f.write(json.dumps(state_dict_1_keys, indent=4)) + +with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}_loop.json'), 'w') as f: + f.write(json.dumps(state_dict_2_keys, indent=4)) + +with open(json_save_path, 'w') as f: + f.write(json_data) diff --git a/ai-toolkit/testing/test_vae.py b/ai-toolkit/testing/test_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..463ab555896ea953b042380fb464c7beef6f5933 --- /dev/null +++ b/ai-toolkit/testing/test_vae.py @@ -0,0 +1,130 @@ +import argparse +import os +from PIL import Image +import torch +from torchvision.transforms import Resize, ToTensor +from diffusers import AutoencoderKL +from pytorch_fid import fid_score +from skimage.metrics import peak_signal_noise_ratio as psnr +import lpips +from tqdm import tqdm +from torchvision import transforms + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def load_images(folder_path): + images = [] + for filename in os.listdir(folder_path): + if filename.lower().endswith(('.png', '.jpg', '.jpeg')): + img_path = os.path.join(folder_path, filename) + images.append(img_path) + return images + + +def paramiter_count(model): + state_dict = model.state_dict() + paramiter_count = 0 + for key in 
state_dict: + paramiter_count += torch.numel(state_dict[key]) + return int(paramiter_count) + + +def calculate_metrics(vae, images, max_imgs=-1, save_output=False): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + vae = vae.to(device) + lpips_model = lpips.LPIPS(net='alex').to(device) + + rfid_scores = [] + psnr_scores = [] + lpips_scores = [] + + # transform = transforms.Compose([ + # transforms.Resize(256, antialias=True), + # transforms.CenterCrop(256) + # ]) + # needs values between -1 and 1 + to_tensor = ToTensor() + + # remove _reconstructed.png files + images = [img for img in images if not img.endswith("_reconstructed.png")] + + if max_imgs > 0 and len(images) > max_imgs: + images = images[:max_imgs] + + for img_path in tqdm(images): + try: + img = Image.open(img_path).convert('RGB') + # img_tensor = to_tensor(transform(img)).unsqueeze(0).to(device) + img_tensor = to_tensor(img).unsqueeze(0).to(device) + img_tensor = 2 * img_tensor - 1 + # if width or height is not divisible by 8, crop it + if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0: + img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8] + + except Exception as e: + print(f"Error processing {img_path}: {e}") + continue + + + with torch.no_grad(): + reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample + + # Calculate rFID + # rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed) + # rfid_scores.append(rfid) + + # Calculate PSNR + psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy()) + psnr_scores.append(psnr_val) + + # Calculate LPIPS + lpips_val = lpips_model(img_tensor, reconstructed).item() + lpips_scores.append(lpips_val) + + # avg_rfid = sum(rfid_scores) / len(rfid_scores) + avg_rfid = 0 + avg_psnr = sum(psnr_scores) / len(psnr_scores) + avg_lpips = sum(lpips_scores) / len(lpips_scores) + + if save_output: + filename_no_ext = 
os.path.splitext(os.path.basename(img_path))[0] + folder = os.path.dirname(img_path) + save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png") + reconstructed = (reconstructed + 1) / 2 + reconstructed = reconstructed.clamp(0, 1) + reconstructed = transforms.ToPILImage()(reconstructed[0].cpu()) + reconstructed.save(save_path) + + return avg_rfid, avg_psnr, avg_lpips + + +def main(): + parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions") + parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model") + parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images") + parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.") + # boolean store true + parser.add_argument("--save_output", action="store_true", help="Save the output images") + args = parser.parse_args() + + if os.path.isfile(args.vae_path): + vae = AutoencoderKL.from_single_file(args.vae_path) + else: + try: + vae = AutoencoderKL.from_pretrained(args.vae_path) + except: + vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") + vae.eval() + vae = vae.to(device) + print(f"Model has {paramiter_count(vae)} parameters") + images = load_images(args.image_folder) + + avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output) + + # print(f"Average rFID: {avg_rfid}") + print(f"Average PSNR: {avg_psnr}") + print(f"Average LPIPS: {avg_lpips}") + + +if __name__ == "__main__": + main() diff --git a/ai-toolkit/testing/test_vae_cycle.py b/ai-toolkit/testing/test_vae_cycle.py new file mode 100644 index 0000000000000000000000000000000000000000..175e8f8fa5cdb4cb652225f4d95e7a2cbb04fd29 --- /dev/null +++ b/ai-toolkit/testing/test_vae_cycle.py @@ -0,0 +1,112 @@ +import os + +import torch +from safetensors.torch import load_file +from collections import OrderedDict +from 
toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm, vae_keys_squished_on_diffusers +import json +# this was just used to match the vae keys to the diffusers keys +# you probably wont need this. Unless they change them.... again... again +# on second thought, you probably will + +device = torch.device('cpu') +dtype = torch.float32 +vae_path = '/mnt/Models/stable-diffusion/models/VAE/vae-ft-mse-840000-ema-pruned/vae-ft-mse-840000-ema-pruned.safetensors' + +find_matches = False + +state_dict_ldm = load_file(vae_path) +diffusers_vae = load_vae(vae_path, dtype=torch.float32).to(device) + +ldm_keys = state_dict_ldm.keys() + +matched_keys = {} +duplicated_keys = { + +} + +if find_matches: + # find values that match with a very low mse + for ldm_key in ldm_keys: + ldm_value = state_dict_ldm[ldm_key] + for diffusers_key in list(diffusers_vae.state_dict().keys()): + diffusers_value = diffusers_vae.state_dict()[diffusers_key] + if diffusers_key in vae_keys_squished_on_diffusers: + diffusers_value = diffusers_value.clone().unsqueeze(-1).unsqueeze(-1) + # if they are not same shape, skip + if ldm_value.shape != diffusers_value.shape: + continue + mse = torch.nn.functional.mse_loss(ldm_value, diffusers_value) + if mse < 1e-6: + if ldm_key in list(matched_keys.keys()): + print(f'{ldm_key} already matched to {matched_keys[ldm_key]}') + if ldm_key in duplicated_keys: + duplicated_keys[ldm_key].append(diffusers_key) + else: + duplicated_keys[ldm_key] = [diffusers_key] + continue + matched_keys[ldm_key] = diffusers_key + is_matched = True + break + + print(f'Found {len(matched_keys)} matches') + +dif_to_ldm_state_dict = convert_diffusers_back_to_ldm(diffusers_vae) +dif_to_ldm_state_dict_keys = list(dif_to_ldm_state_dict.keys()) +keys_in_both = [] + +keys_not_in_diffusers = [] +for key in ldm_keys: + if key not in dif_to_ldm_state_dict_keys: + keys_not_in_diffusers.append(key) + +keys_not_in_ldm = [] +for key in dif_to_ldm_state_dict_keys: + if key not in 
ldm_keys: + keys_not_in_ldm.append(key) + +keys_in_both = [] +for key in ldm_keys: + if key in dif_to_ldm_state_dict_keys: + keys_in_both.append(key) + +# sort them +keys_not_in_diffusers.sort() +keys_not_in_ldm.sort() +keys_in_both.sort() + +# print(f'Keys in LDM but not in Diffusers: {len(keys_not_in_diffusers)}{keys_not_in_diffusers}') +# print(f'Keys in Diffusers but not in LDM: {len(keys_not_in_ldm)}{keys_not_in_ldm}') +# print(f'Keys in both: {len(keys_in_both)}{keys_in_both}') + +json_data = { + "both": keys_in_both, + "ldm": keys_not_in_diffusers, + "diffusers": keys_not_in_ldm +} +json_data = json.dumps(json_data, indent=4) + +remaining_diffusers_values = OrderedDict() +for key in keys_not_in_ldm: + remaining_diffusers_values[key] = dif_to_ldm_state_dict[key] + +# print(remaining_diffusers_values.keys()) + +remaining_ldm_values = OrderedDict() +for key in keys_not_in_diffusers: + remaining_ldm_values[key] = state_dict_ldm[key] + +# print(json_data) + +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +json_save_path = os.path.join(project_root, 'config', 'keys.json') +json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') +json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') + +with open(json_save_path, 'w') as f: + f.write(json_data) +if find_matches: + with open(json_matched_save_path, 'w') as f: + f.write(json.dumps(matched_keys, indent=4)) + with open(json_duped_save_path, 'w') as f: + f.write(json.dumps(duplicated_keys, indent=4)) diff --git a/ai-toolkit/toolkit/__init__.py b/ai-toolkit/toolkit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae76ce004f4922af10c611b648c4118f53d4a69 --- /dev/null +++ b/ai-toolkit/toolkit/__init__.py @@ -0,0 +1,24 @@ +import importlib +import logging + + +def force_hf_hub_progress_bars(): + hf_tqdm = importlib.import_module('huggingface_hub.utils.tqdm') + + original_is_tqdm_disabled = hf_tqdm.is_tqdm_disabled + if 
getattr(original_is_tqdm_disabled, '_aitk_forced_progress', False): + return + + def is_tqdm_disabled(log_level): + disabled = original_is_tqdm_disabled(log_level) + # UI jobs log to files, so stderr is not a TTY. Keep HF download bars + # visible while preserving normal carriage-return progress updates. + if disabled is None and log_level != logging.NOTSET: + return False + return disabled + + is_tqdm_disabled._aitk_forced_progress = True + hf_tqdm.is_tqdm_disabled = is_tqdm_disabled + + +force_hf_hub_progress_bars() diff --git a/ai-toolkit/toolkit/accelerator.py b/ai-toolkit/toolkit/accelerator.py new file mode 100644 index 0000000000000000000000000000000000000000..0736f0167c0f74ee327cea87653b57c1a1d6ff6f --- /dev/null +++ b/ai-toolkit/toolkit/accelerator.py @@ -0,0 +1,20 @@ +from accelerate import Accelerator +from diffusers.utils.torch_utils import is_compiled_module + +global_accelerator = None + + +def get_accelerator() -> Accelerator: + global global_accelerator + if global_accelerator is None: + global_accelerator = Accelerator() + return global_accelerator + +def unwrap_model(model): + try: + accelerator = get_accelerator() + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + except Exception as e: + pass + return model diff --git a/ai-toolkit/toolkit/advanced_prompt_embeds.py b/ai-toolkit/toolkit/advanced_prompt_embeds.py new file mode 100644 index 0000000000000000000000000000000000000000..b86adeb613f50a32361c63cf7837a5fe3b07bebb --- /dev/null +++ b/ai-toolkit/toolkit/advanced_prompt_embeds.py @@ -0,0 +1,195 @@ +import os +import torch +from safetensors.torch import load_file, save_file + + +class AdvancedPromptEmbeds: + """ + Flexible container for prompt embedding tensors. + + Each value passed in must be a list of tensors, where each item in the + list corresponds to a single item in the batch (list length == batch size). 
+ Do not store more than one tensor per batch item under the same key — if + you need multiple tensors per batch item, give them different key names. + + Usage: + pe = AdvancedPromptEmbeds( + prompt_embeds=[t0, t1, t2], # one tensor per batch item + pooled_embeds=[p0, p1, p2], + ) + + pe.prompt_embeds # -> [t0, t1, t2] + pe['prompt_embeds'] # -> [t0, t1, t2] + pe.keys() # -> ['prompt_embeds', 'pooled_embeds'] + + # add more after init + pe.extra = [e0, e1, e2] + pe['extra2'] = [e0, e1, e2] + pe.set('extra3', [e0, e1, e2]) + pe.update(extra4=[e0, e1, e2]) + """ + + def __init__(self, **kwargs): + self._store = {} + self._frozen_dtype_keys = [] + for key, value in kwargs.items(): + if not isinstance(value, list): + value = [value] + self._store[key] = value + + @property + def frozen_dtype_keys(self): + return self._frozen_dtype_keys + + @frozen_dtype_keys.setter + def frozen_dtype_keys(self, keys): + self._frozen_dtype_keys = list(keys) if keys else [] + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(name) + store = self.__dict__.get("_store", {}) + if name in store: + return store[name] + raise AttributeError(f"{type(self).__name__!s} has no attribute {name!r}") + + def __setattr__(self, name, value): + if name.startswith("_"): + super().__setattr__(name, value) + return + cls_attr = getattr(type(self), name, None) + if isinstance(cls_attr, property): + super().__setattr__(name, value) + return + if not isinstance(value, list): + value = [value] + self._store[name] = value + + def set(self, key, value): + if not isinstance(value, list): + value = [value] + self._store[key] = value + + def update(self, **kwargs): + for key, value in kwargs.items(): + if not isinstance(value, list): + value = [value] + self._store[key] = value + + def keys(self): + return list(self._store.keys()) + + def __getitem__(self, key): + return self._store[key] + + def __setitem__(self, key, value): + if not isinstance(value, list): + value = [value] + 
self._store[key] = value + + def __contains__(self, key): + return key in self._store + + def to(self, *args, **kwargs): + frozen = set(self._frozen_dtype_keys) + if frozen: + no_dtype_args = [a for a in args if not isinstance(a, torch.dtype)] + no_dtype_kwargs = {k: v for k, v in kwargs.items() if k != "dtype"} + new_pe = AdvancedPromptEmbeds() + new_pe._frozen_dtype_keys = list(self._frozen_dtype_keys) + for key, value in self._store.items(): + if key in frozen: + new_pe._store[key] = [ + v.to(*no_dtype_args, **no_dtype_kwargs) for v in value + ] + else: + new_pe._store[key] = [v.to(*args, **kwargs) for v in value] + return new_pe + + def detach(self): + new_pe = AdvancedPromptEmbeds() + new_pe._frozen_dtype_keys = list(self._frozen_dtype_keys) + for key, value in self._store.items(): + new_pe._store[key] = [v.detach() for v in value] + return new_pe + + def clone(self): + new_pe = AdvancedPromptEmbeds() + new_pe._frozen_dtype_keys = list(self._frozen_dtype_keys) + for key, value in self._store.items(): + new_pe._store[key] = [v.clone() for v in value] + return new_pe + + def expand_to_batch(self, batch_size): + new_pe = AdvancedPromptEmbeds() + new_pe._frozen_dtype_keys = list(self._frozen_dtype_keys) + for key, value in self._store.items(): + if len(value) == 1: + new_pe._store[key] = value * batch_size + elif len(value) == batch_size: + new_pe._store[key] = value + else: + raise ValueError( + f"Cannot expand key {key!r}: expected list of length 1 or {batch_size}, got {len(value)}" + ) + return new_pe + + def save(self, path): + data = {} + metadata = {"class_name": self.__class__.__name__} + for key, value in self._store.items(): + if len(value) != 1: + raise ValueError( + f"Cannot save key {key!r}: expected list of length 1, got {len(value)}" + ) + data[key] = value[0] + os.makedirs(os.path.dirname(path), exist_ok=True) + save_file(data, path, metadata=metadata) + + @classmethod + def load(cls, path=None): + if path is not None: + loaded = load_file(path) + 
else: + raise ValueError("Must provide a path") + + data = {} + for key in loaded.keys(): + data[key] = loaded[key] + + return cls(**data) + + @classmethod + def concat_prompt_embeds( + cls, prompt_embeds: list["AdvancedPromptEmbeds"], padding_side: str = "right" + ): + embeds = {} + frozen = [] + for pe in prompt_embeds: + for key in pe.keys(): + if key not in embeds: + embeds[key] = [] + embeds[key].extend(pe[key]) + for k in pe.frozen_dtype_keys: + if k not in frozen: + frozen.append(k) + out = cls(**embeds) + out.frozen_dtype_keys = frozen + return out + + @classmethod + def split_prompt_embeds(cls, concatenated: "AdvancedPromptEmbeds", num_parts=None): + if num_parts is None: + # use length of first item as num_parts + num_parts = len(concatenated[concatenated.keys()[0]]) + split_embeds = [cls() for _ in range(num_parts)] + for pe in split_embeds: + pe.frozen_dtype_keys = list(concatenated.frozen_dtype_keys) + for key in concatenated.keys(): + values = concatenated[key] + if len(values) != num_parts: + raise ValueError( + f"Cannot split key {key!r}: expected list of length {num_parts}, got {len(values)}" + ) + for i in range(num_parts): + split_embeds[i]._store[key] = [values[i]] diff --git a/ai-toolkit/toolkit/assistant_lora.py b/ai-toolkit/toolkit/assistant_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeca968ad6c6d5b403f3786eea81efc33944c94 --- /dev/null +++ b/ai-toolkit/toolkit/assistant_lora.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING +from toolkit.config_modules import NetworkConfig +from toolkit.lora_special import LoRASpecialNetwork +from safetensors.torch import load_file + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + + +def load_assistant_lora_from_path(adapter_path, sd: 'StableDiffusion') -> LoRASpecialNetwork: + if not sd.is_flux: + raise ValueError("Only Flux models can load assistant adapters currently.") + pipe = sd.pipeline + print(f"Loading assistant adapter from 
{adapter_path}") + adapter_name = adapter_path.split("/")[-1].split(".")[0] + lora_state_dict = load_file(adapter_path) + + linear_dim = int(lora_state_dict['transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight'].shape[0]) + # linear_alpha = int(lora_state_dict['lora_transformer_single_transformer_blocks_0_attn_to_k.alpha'].item()) + linear_alpha = linear_dim + transformer_only = 'transformer.proj_out.alpha' not in lora_state_dict + # get dim and scale + network_config = NetworkConfig( + linear=linear_dim, + linear_alpha=linear_alpha, + transformer_only=transformer_only, + ) + + network = LoRASpecialNetwork( + text_encoder=pipe.text_encoder, + unet=pipe.transformer, + lora_dim=network_config.linear, + multiplier=1.0, + alpha=network_config.linear_alpha, + train_unet=True, + train_text_encoder=False, + is_flux=True, + network_config=network_config, + network_type=network_config.type, + transformer_only=network_config.transformer_only, + is_assistant_adapter=True + ) + network.apply_to( + pipe.text_encoder, + pipe.transformer, + apply_text_encoder=False, + apply_unet=True + ) + network.force_to(sd.device_torch, dtype=sd.torch_dtype) + network.eval() + network._update_torch_multiplier() + network.load_weights(lora_state_dict) + network.is_active = True + + return network diff --git a/ai-toolkit/toolkit/audio/album_artwork.py b/ai-toolkit/toolkit/audio/album_artwork.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ef4b397b83d3e367db0da7ba3ee2ba451ac0ef --- /dev/null +++ b/ai-toolkit/toolkit/audio/album_artwork.py @@ -0,0 +1,119 @@ +import io +import os +import numpy as np +import av +from PIL import Image, ImageDraw + + +ARTWORK_DIR = os.path.dirname(os.path.abspath(__file__)) +BACKGROUND_PATH = os.path.join(ARTWORK_DIR, "ostris_logo.jpg") +WAVEFORM_COLOR = (0xFB, 0xBF, 0x24, 230) # #fbbf24 at 90% opacity +ARTWORK_SIZE = 1024 + + +def load_waveform(audio_path: str, num_samples: int = 512) -> np.ndarray: + """Load audio and return a 
downsampled waveform envelope using PyAV.""" + container = av.open(audio_path) + stream = container.streams.audio[0] + stream.codec_context.thread_type = "AUTO" + + frames = [] + for frame in container.decode(stream): + arr = frame.to_ndarray() + # mix down to mono + if arr.ndim > 1: + arr = arr.mean(axis=0) + frames.append(arr) + container.close() + + audio = np.concatenate(frames) + + # downsample to num_samples bins by taking max absolute value per bin + bin_size = len(audio) // num_samples + if bin_size == 0: + bin_size = 1 + trimmed = audio[: bin_size * num_samples] + bins = trimmed.reshape(num_samples, bin_size) + envelope = np.max(np.abs(bins), axis=1) + + # normalize to 0-1 + peak = envelope.max() + if peak > 0: + envelope = envelope / peak + return envelope + + +def create_artwork(waveform: np.ndarray, size: int = ARTWORK_SIZE) -> Image.Image: + """Create album artwork with logo background and waveform overlay.""" + bg = Image.open(BACKGROUND_PATH).convert("RGBA").resize((size, size), Image.LANCZOS) + + # draw waveform on separate overlay for alpha compositing + wave_overlay = Image.new("RGBA", (size, size), (0, 0, 0, 0)) + draw = ImageDraw.Draw(wave_overlay) + + num_bars = len(waveform) + padding = int(size * 0.02) + draw_w = size - 2 * padding + bar_width = max(1, draw_w / num_bars) + center_y = size // 2 + + max_amp = (size // 2) * 0.85 # leave a little margin + + for i, amp in enumerate(waveform): + x = padding + i * bar_width + h = amp * max_amp + y_top = center_y - h + y_bot = center_y + h + draw.rectangle( + [x, y_top, x + bar_width - 1, y_bot], + fill=WAVEFORM_COLOR, + ) + + bg = Image.alpha_composite(bg, wave_overlay) + return bg.convert("RGB") + + +def add_album_artwork(song_path: str) -> None: + """Add album artwork with waveform visualization to an MP3 file.""" + from mutagen.id3 import ID3, APIC, ID3NoHeaderError + + if not os.path.isfile(song_path): + raise FileNotFoundError(f"Audio file not found: {song_path}") + + waveform = 
load_waveform(song_path) + artwork = create_artwork(waveform) + + # encode artwork to JPEG bytes in memory + buf = io.BytesIO() + artwork.save(buf, format="JPEG", quality=85) + artwork_data = buf.getvalue() + + # embed into MP3 via mutagen ID3 tags + try: + tags = ID3(song_path) + except ID3NoHeaderError: + tags = ID3() + + tags.delall("APIC") + tags.add( + APIC( + encoding=3, # UTF-8 + mime="image/jpeg", + type=3, # front cover + desc="Cover", + data=artwork_data, + ) + ) + tags.save(song_path, v2_version=3) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Add album artwork with waveform to an MP3 file" + ) + parser.add_argument("mp3", help="Path to the MP3 file") + args = parser.parse_args() + + add_album_artwork(args.mp3) diff --git a/ai-toolkit/toolkit/audio/make_video.py b/ai-toolkit/toolkit/audio/make_video.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc95561f4753a6064972e4e6e7794364bc806a4 --- /dev/null +++ b/ai-toolkit/toolkit/audio/make_video.py @@ -0,0 +1,149 @@ +import os +import numpy as np +import av +from PIL import Image, ImageDraw + + +ARTWORK_DIR = os.path.dirname(os.path.abspath(__file__)) +BACKGROUND_PATH = os.path.join(ARTWORK_DIR, "ostris_logo.jpg") +WAVEFORM_COLOR = (0xFB, 0xBF, 0x24, 230) # #fbbf24 at 90% opacity +ARTWORK_SIZE = 1024 + + +def load_waveform(audio_path: str, num_samples: int = 512) -> np.ndarray: + """Load audio and return a downsampled waveform envelope using PyAV.""" + container = av.open(audio_path) + stream = container.streams.audio[0] + stream.codec_context.thread_type = "AUTO" + + frames = [] + for frame in container.decode(stream): + arr = frame.to_ndarray() + # mix down to mono + if arr.ndim > 1: + arr = arr.mean(axis=0) + frames.append(arr) + container.close() + + audio = np.concatenate(frames) + + # downsample to num_samples bins by taking max absolute value per bin + bin_size = len(audio) // num_samples + if bin_size == 0: + bin_size = 
1 + trimmed = audio[: bin_size * num_samples] + bins = trimmed.reshape(num_samples, bin_size) + envelope = np.max(np.abs(bins), axis=1) + + # normalize to 0-1 + peak = envelope.max() + if peak > 0: + envelope = envelope / peak + return envelope + + +def create_artwork(waveform: np.ndarray, size: int = ARTWORK_SIZE) -> Image.Image: + """Create album artwork with logo background and waveform overlay.""" + bg = Image.open(BACKGROUND_PATH).convert("RGBA").resize((size, size), Image.LANCZOS) + + # draw waveform on separate overlay for alpha compositing + wave_overlay = Image.new("RGBA", (size, size), (0, 0, 0, 0)) + draw = ImageDraw.Draw(wave_overlay) + + num_bars = len(waveform) + padding = int(size * 0.02) + draw_w = size - 2 * padding + bar_width = max(1, draw_w / num_bars) + center_y = size // 2 + + max_amp = (size // 2) * 0.85 # leave a little margin + + for i, amp in enumerate(waveform): + x = padding + i * bar_width + h = amp * max_amp + y_top = center_y - h + y_bot = center_y + h + draw.rectangle( + [x, y_top, x + bar_width - 1, y_bot], + fill=WAVEFORM_COLOR, + ) + + bg = Image.alpha_composite(bg, wave_overlay) + return bg.convert("RGB") + + +def make_video(song_path: str, video_size: int = 512) -> str: + """Create an MP4 video with album artwork as a static image for the duration of the audio.""" + if not os.path.isfile(song_path): + raise FileNotFoundError(f"Audio file not found: {song_path}") + + waveform = load_waveform(song_path) + artwork = create_artwork(waveform) + artwork = artwork.resize((video_size, video_size), Image.LANCZOS) + + # get audio duration + container = av.open(song_path) + duration = float(container.duration) / av.time_base + container.close() + + # output path: same name as input but .mp4, in the same directory + base, _ = os.path.splitext(song_path) + output_path = base + ".mp4" + + fps = 1 # static image, 1 fps is enough + total_frames = max(1, int(duration * fps)) + + # convert artwork to numpy array for video encoding + frame_data = 
np.array(artwork) + + out_container = av.open(output_path, mode="w") + video_stream = out_container.add_stream("libx264", rate=fps) + video_stream.width = video_size + video_stream.height = video_size + video_stream.pix_fmt = "yuv420p" + + for _ in range(total_frames): + frame = av.VideoFrame.from_ndarray(frame_data, format="rgb24") + for packet in video_stream.encode(frame): + out_container.mux(packet) + + # flush + for packet in video_stream.encode(): + out_container.mux(packet) + + out_container.close() + + # mux audio into the video using ffmpeg via subprocess + import subprocess + final_path = base + "_final.mp4" + subprocess.run( + [ + "ffmpeg", "-y", + "-i", output_path, + "-i", song_path, + "-c:v", "copy", + "-c:a", "aac", + "-shortest", + final_path, + ], + check=True, + capture_output=True, + ) + + # replace silent video with final muxed version + os.replace(final_path, output_path) + + return output_path + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Create an MP4 video with album artwork from an audio file" + ) + parser.add_argument("audio", help="Path to the audio file") + args = parser.parse_args() + + out = make_video(args.audio) + print(f"Created video: {out}") diff --git a/ai-toolkit/toolkit/audio/preserve_pitch.py b/ai-toolkit/toolkit/audio/preserve_pitch.py new file mode 100644 index 0000000000000000000000000000000000000000..501c139dc30020775a437f3edea8dd59ce02abcf --- /dev/null +++ b/ai-toolkit/toolkit/audio/preserve_pitch.py @@ -0,0 +1,75 @@ +import math +import torch +import torch.nn.functional as F +import torchaudio + +def time_stretch_preserve_pitch(waveform: torch.Tensor, sample_rate: int, target_samples: int) -> torch.Tensor: + """ + waveform: [C, L] float tensor (CPU or GPU) + returns: [C, target_samples] float tensor + Pitch-preserving time stretch to match target_samples. 
+ """ + + + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + + waveform = waveform.to(torch.float32) + + src_len = waveform.shape[-1] + if src_len == 0 or target_samples <= 0: + return waveform[..., :0] + + if src_len == target_samples: + return waveform + + # rate > 1.0 speeds up (shorter), rate < 1.0 slows down (longer) + rate = float(src_len) / float(target_samples) + + # Use sample_rate to pick STFT params + win_seconds = 0.046 + hop_seconds = 0.0115 + + n_fft_target = int(sample_rate * win_seconds) + n_fft = 1 << max(8, int(math.floor(math.log2(max(256, n_fft_target))))) # >=256, pow2 + win_length = n_fft + hop_length = max(64, int(sample_rate * hop_seconds)) + hop_length = min(hop_length, win_length // 2) + + window = torch.hann_window(win_length, device=waveform.device, dtype=waveform.dtype) + + stft = torch.stft( + waveform, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=True, + return_complex=True, + ) # [C, F, T] complex + + # IMPORTANT: n_freq must match STFT's frequency bins (n_fft//2 + 1) + stretcher = torchaudio.transforms.TimeStretch( + n_freq=stft.shape[-2], + hop_length=hop_length, + fixed_rate=rate, + ).to(waveform.device) + + stft_stretched = stretcher(stft) # [C, F, T'] + + stretched = torch.istft( + stft_stretched, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=True, + length=target_samples, + ) + + if stretched.shape[-1] > target_samples: + stretched = stretched[..., :target_samples] + elif stretched.shape[-1] < target_samples: + stretched = F.pad(stretched, (0, target_samples - stretched.shape[-1])) + + return stretched diff --git a/ai-toolkit/toolkit/basic.py b/ai-toolkit/toolkit/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..3784445717de37b62ce5d749ce17a01b43d2fd35 --- /dev/null +++ b/ai-toolkit/toolkit/basic.py @@ -0,0 +1,70 @@ +import gc +import os + +import torch + + +def value_map(inputs, min_in, max_in, 
min_out, max_out): + return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out + + +def flush(garbage_collect=True): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # if is mps, also clear the mps cache + if torch.backends.mps.is_available(): + torch.mps.empty_cache() + if garbage_collect: + gc.collect() + + +def get_mean_std(tensor): + if len(tensor.shape) == 3: + tensor = tensor.unsqueeze(0) + elif len(tensor.shape) != 4: + raise Exception("Expected tensor of shape (batch_size, channels, width, height)") + mean, variance = torch.mean( + tensor, dim=[2, 3], keepdim=True + ), torch.var( + tensor, dim=[2, 3], + keepdim=True + ) + std = torch.sqrt(variance + 1e-5) + return mean, std + + +def adain(content_features, style_features): + # Assumes that the content and style features are of shape (batch_size, channels, width, height) + + dims = [2, 3] + if len(content_features.shape) == 3: + # content_features = content_features.unsqueeze(0) + # style_features = style_features.unsqueeze(0) + dims = [1] + + # Step 1: Calculate mean and variance of content features + content_mean, content_var = torch.mean(content_features, dim=dims, keepdim=True), torch.var(content_features, + dim=dims, + keepdim=True) + # Step 2: Calculate mean and variance of style features + style_mean, style_var = torch.mean(style_features, dim=dims, keepdim=True), torch.var(style_features, dim=dims, + keepdim=True) + + # Step 3: Normalize content features + content_std = torch.sqrt(content_var + 1e-5) + normalized_content = (content_features - content_mean) / content_std + + # Step 4: Scale and shift normalized content with style's statistics + style_std = torch.sqrt(style_var + 1e-5) + stylized_content = normalized_content * style_std + style_mean + + return stylized_content + +def get_quick_signature_string(file_path): + try: + file_stats = os.stat(file_path) + # Combine size and mtime into a single string + return f"{file_stats.st_size}:{int(file_stats.st_mtime)}" + 
except Exception as e: + print(f"Error accessing file {file_path}: {e}") + return None \ No newline at end of file diff --git a/ai-toolkit/toolkit/buckets.py b/ai-toolkit/toolkit/buckets.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b12ce8fa1ab6fdfcd4be984dab099a154b66df --- /dev/null +++ b/ai-toolkit/toolkit/buckets.py @@ -0,0 +1,48 @@ +import math +from typing import TypedDict + + +class BucketResolution(TypedDict): + width: int + height: int + + +def get_resolution(width, height): + num_pixels = width * height + # determine same number of pixels for square image + square_resolution = int(num_pixels**0.5) + return square_resolution + + +def get_bucket_for_image_size( + width: int, height: int, resolution: int = 512, divisibility: int = 8 +) -> BucketResolution: + total_pixels = width * height + max_pixels = resolution * resolution + + target_pixels = min(total_pixels, max_pixels) + + scaler = (target_pixels / total_pixels) ** 0.5 + w_raw = (width * scaler) / divisibility + h_raw = (height * scaler) / divisibility + + candidates = [ + (math.floor(w_raw) * divisibility, math.floor(h_raw) * divisibility), + (math.floor(w_raw) * divisibility, math.ceil(h_raw) * divisibility), + (math.ceil(w_raw) * divisibility, math.floor(h_raw) * divisibility), + (math.ceil(w_raw) * divisibility, math.ceil(h_raw) * divisibility), + ] + capped = [(w, h) for w, h in candidates if w > 0 and h > 0 and w * h <= max_pixels] + if not capped: + capped = [ + ( + max(divisibility, math.floor(w_raw) * divisibility), + max(divisibility, math.floor(h_raw) * divisibility), + ) + ] + + new_width, new_height = min( + capped, key=lambda wh: abs(wh[0] * wh[1] - target_pixels) + ) + + return {"width": new_width, "height": new_height} diff --git a/ai-toolkit/toolkit/civitai.py b/ai-toolkit/toolkit/civitai.py new file mode 100644 index 0000000000000000000000000000000000000000..ef505ad833f951470eb2e6a9c7b26059a6509604 --- /dev/null +++ b/ai-toolkit/toolkit/civitai.py @@ -0,0 
+1,217 @@ +from toolkit.paths import MODELS_PATH +import requests +import os +import json +import tqdm + + +class ModelCache: + def __init__(self): + self.raw_cache = {} + self.cache_path = os.path.join(MODELS_PATH, '.ai_toolkit_cache.json') + if os.path.exists(self.cache_path): + with open(self.cache_path, 'r') as f: + all_cache = json.load(f) + if 'models' in all_cache: + self.raw_cache = all_cache['models'] + else: + self.raw_cache = all_cache + + def get_model_path(self, model_id: int, model_version_id: int = None): + if str(model_id) not in self.raw_cache: + return None + if model_version_id is None: + # get latest version + model_version_id = max([int(x) for x in self.raw_cache[str(model_id)].keys()]) + if model_version_id is None: + return None + model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path'] + # check if model path exists + if not os.path.exists(model_path): + # remove version from cache + del self.raw_cache[str(model_id)][str(model_version_id)] + self.save() + return None + return model_path + else: + if str(model_version_id) not in self.raw_cache[str(model_id)]: + return None + model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path'] + # check if model path exists + if not os.path.exists(model_path): + # remove version from cache + del self.raw_cache[str(model_id)][str(model_version_id)] + self.save() + return None + return model_path + + def update_cache(self, model_id: int, model_version_id: int, model_path: str): + if str(model_id) not in self.raw_cache: + self.raw_cache[str(model_id)] = {} + if str(model_version_id) not in self.raw_cache[str(model_id)]: + self.raw_cache[str(model_id)][str(model_version_id)] = {} + self.raw_cache[str(model_id)][str(model_version_id)] = { + 'model_path': model_path + } + self.save() + + def save(self): + if not os.path.exists(os.path.dirname(self.cache_path)): + os.makedirs(os.path.dirname(self.cache_path), exist_ok=True) + all_cache = {'models': {}} + if 
os.path.exists(self.cache_path): + # load it first + with open(self.cache_path, 'r') as f: + all_cache = json.load(f) + + all_cache['models'] = self.raw_cache + + with open(self.cache_path, 'w') as f: + json.dump(all_cache, f, indent=2) + + +def get_model_download_info(model_id: int, model_version_id: int = None): + # curl https://civitai.com/api/v1/models?limit=3&types=TextualInversion \ + # -H "Content-Type: application/json" \ + # -X GET + print( + f"Getting model info for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}") + endpoint = f"https://civitai.com/api/v1/models/{model_id}" + + # get the json + response = requests.get(endpoint) + response.raise_for_status() + model_data = response.json() + + model_version = None + + # go through versions and get the top one if one is not set + for version in model_data['modelVersions']: + if model_version_id is not None: + if str(version['id']) == str(model_version_id): + model_version = version + break + else: + # get first version + model_version = version + break + + if model_version is None: + raise ValueError( + f"Could not find a model version for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}") + + model_file = None + # go through files and prefer fp16 safetensors + # "metadata": { + # "fp": "fp16", + # "size": "pruned", + # "format": "SafeTensor" + # }, + # todo check pickle scans and skip if not good + # try to get fp16 safetensor + for file in model_version['files']: + if file['metadata']['fp'] == 'fp16' and file['metadata']['format'] == 'SafeTensor': + model_file = file + break + + if model_file is None: + # try to get primary + for file in model_version['files']: + if file['primary']: + model_file = file + break + + if model_file is None: + # try to get any safetensor + for file in model_version['files']: + if file['metadata']['format'] == 'SafeTensor': + model_file = file + break + + if model_file is 
None: + # try to get any fp16 + for file in model_version['files']: + if file['metadata']['fp'] == 'fp16': + model_file = file + break + + if model_file is None: + # try to get any + for file in model_version['files']: + model_file = file + break + + if model_file is None: + raise ValueError(f"Could not find a model file to download for model id: {model_id}") + + return model_file, model_version['id'] + + +def get_model_path_from_url(url: str): + # get query params form url if they are set + # https: // civitai.com / models / 25694?modelVersionId = 127742 + query_params = {} + if '?' in url: + query_string = url.split('?')[1] + query_params = dict(qc.split("=") for qc in query_string.split("&")) + + # get model id from url + model_id = url.split('/')[-1] + # remove query params from model id + if '?' in model_id: + model_id = model_id.split('?')[0] + if model_id.isdigit(): + model_id = int(model_id) + else: + raise ValueError(f"Invalid model id: {model_id}") + + model_cache = ModelCache() + model_path = model_cache.get_model_path(model_id, query_params.get('modelVersionId', None)) + if model_path is not None: + return model_path + else: + # download model + file_info, model_version_id = get_model_download_info(model_id, query_params.get('modelVersionId', None)) + + download_url = file_info['downloadUrl'] # url does not work directly + size_kb = file_info['sizeKB'] + filename = file_info['name'] + model_path = os.path.join(MODELS_PATH, filename) + + # download model + print(f"Did not find model locally, downloading from model from: {download_url}") + + # use tqdm to show status of downlod + response = requests.get(download_url, stream=True) + response.raise_for_status() + total_size_in_bytes = int(response.headers.get('content-length', 0)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) + tmp_path = os.path.join(MODELS_PATH, f".download_tmp_{filename}") + os.makedirs(os.path.dirname(model_path), 
exist_ok=True) + # remove tmp file if it exists + if os.path.exists(tmp_path): + os.remove(tmp_path) + + try: + + with open(tmp_path, 'wb') as f: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + f.write(data) + progress_bar.close() + # move to final path + os.rename(tmp_path, model_path) + model_cache.update_cache(model_id, model_version_id, model_path) + + return model_path + except Exception as e: + # remove tmp file + os.remove(tmp_path) + raise e + + +# if is main +if __name__ == '__main__': + model_path = get_model_path_from_url("https://civitai.com/models/25694?modelVersionId=127742") + print(model_path) diff --git a/ai-toolkit/toolkit/clip_vision_adapter.py b/ai-toolkit/toolkit/clip_vision_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccc920caac68aa43a2b1ddc944079d88feb50a4 --- /dev/null +++ b/ai-toolkit/toolkit/clip_vision_adapter.py @@ -0,0 +1,406 @@ +from typing import TYPE_CHECKING, Mapping, Any + +import torch +import weakref + +from toolkit.config_modules import AdapterConfig +from toolkit.models.clip_fusion import ZipperBlock +from toolkit.models.zipper_resampler import ZipperModule +from toolkit.prompt_utils import PromptEmbeds +from toolkit.train_tools import get_torch_dtype + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPVisionModel +) + +from toolkit.resampler import Resampler + +import torch.nn as nn + + +class Embedder(nn.Module): + def __init__( + self, + num_input_tokens: int = 1, + input_dim: int = 1024, + num_output_tokens: int = 8, + output_dim: int = 768, + mid_dim: int = 1024 + ): + super(Embedder, self).__init__() + self.num_output_tokens = num_output_tokens + self.num_input_tokens = num_input_tokens + self.input_dim = input_dim + self.output_dim = output_dim + + self.layer_norm = nn.LayerNorm(input_dim) + self.fc1 = nn.Linear(input_dim, mid_dim) 
+ self.gelu = nn.GELU() + # self.fc2 = nn.Linear(mid_dim, mid_dim) + self.fc2 = nn.Linear(mid_dim, mid_dim) + + self.fc2.weight.data.zero_() + + self.layer_norm2 = nn.LayerNorm(mid_dim) + self.fc3 = nn.Linear(mid_dim, mid_dim) + self.gelu2 = nn.GELU() + self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens) + + # set the weights to 0 + self.fc3.weight.data.zero_() + self.fc4.weight.data.zero_() + + + # self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim)) + # self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim)) + + def forward(self, x): + if len(x.shape) == 2: + x = x.unsqueeze(1) + x = self.layer_norm(x) + x = self.fc1(x) + x = self.gelu(x) + x = self.fc2(x) + x = self.layer_norm2(x) + x = self.fc3(x) + x = self.gelu2(x) + x = self.fc4(x) + + x = x.view(-1, self.num_output_tokens, self.output_dim) + + return x + + +class ClipVisionAdapter(torch.nn.Module): + def __init__(self, sd: 'StableDiffusion', adapter_config: AdapterConfig): + super().__init__() + self.config = adapter_config + self.trigger = adapter_config.trigger + self.trigger_class_name = adapter_config.trigger_class_name + self.sd_ref: weakref.ref = weakref.ref(sd) + # embedding stuff + self.text_encoder_list = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder] + self.tokenizer_list = sd.tokenizer if isinstance(sd.tokenizer, list) else [sd.tokenizer] + placeholder_tokens = [self.trigger] + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, self.config.num_tokens): + additional_tokens.append(f"{self.trigger}_{i}") + placeholder_tokens += additional_tokens + + # handle dual tokenizer + self.tokenizer_list = self.sd_ref().tokenizer if isinstance(self.sd_ref().tokenizer, list) else [ + self.sd_ref().tokenizer] + self.text_encoder_list = self.sd_ref().text_encoder if isinstance(self.sd_ref().text_encoder, list) else [ + self.sd_ref().text_encoder] + + self.placeholder_token_ids = [] + 
self.embedding_tokens = [] + + print(f"Adding {placeholder_tokens} tokens to tokenizer") + print(f"Adding {self.config.num_tokens} tokens to tokenizer") + + + for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list): + num_added_tokens = tokenizer.add_tokens(placeholder_tokens) + if num_added_tokens != self.config.num_tokens: + raise ValueError( + f"The tokenizer already contains the token {self.trigger}. Please pass a different" + f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}" + ) + + # Convert the initializer_token, placeholder_token to ids + init_token_ids = tokenizer.encode(self.config.trigger_class_name, add_special_tokens=False) + # if length of token ids is more than number of orm embedding tokens fill with * + if len(init_token_ids) > self.config.num_tokens: + init_token_ids = init_token_ids[:self.config.num_tokens] + elif len(init_token_ids) < self.config.num_tokens: + pad_token_id = tokenizer.encode(["*"], add_special_tokens=False) + init_token_ids += pad_token_id * (self.config.num_tokens - len(init_token_ids)) + + placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False) + self.placeholder_token_ids.append(placeholder_token_ids) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + with torch.no_grad(): + for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids): + token_embeds[token_id] = token_embeds[initializer_token_id].clone() + + # replace "[name] with this. on training. 
This is automatically generated in pipeline on inference + self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))) + + # backup text encoder embeddings + self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list] + + try: + self.clip_image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = CLIPImageProcessor() + self.device = self.sd_ref().unet.device + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( + self.config.image_encoder_path, + ignore_mismatched_sizes=True + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + if self.config.train_image_encoder: + self.image_encoder.train() + else: + self.image_encoder.eval() + + # max_seq_len = CLIP tokens + CLS token + image_encoder_state_dict = self.image_encoder.state_dict() + in_tokens = 257 + if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict: + # clip + in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + + if hasattr(self.image_encoder.config, 'hidden_sizes'): + embedding_dim = self.image_encoder.config.hidden_sizes[-1] + else: + embedding_dim = self.image_encoder.config.target_hidden_size + + if self.config.clip_layer == 'image_embeds': + in_tokens = 1 + embedding_dim = self.image_encoder.config.projection_dim + + self.embedder = Embedder( + num_output_tokens=self.config.num_tokens, + num_input_tokens=in_tokens, + input_dim=embedding_dim, + output_dim=self.sd_ref().unet.config['cross_attention_dim'], + mid_dim=embedding_dim * self.config.num_tokens, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + + self.embedder.train() + + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + state_dict = { + 'embedder': self.embedder.state_dict(*args, destination=destination, prefix=prefix, 
keep_vars=keep_vars) + } + if self.config.train_image_encoder: + state_dict['image_encoder'] = self.image_encoder.state_dict( + *args, destination=destination, prefix=prefix, + keep_vars=keep_vars) + + return state_dict + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + self.embedder.load_state_dict(state_dict["embedder"], strict=strict) + if self.config.train_image_encoder and 'image_encoder' in state_dict: + self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict) + + def parameters(self, *args, **kwargs): + yield from self.embedder.parameters(*args, **kwargs) + + def named_parameters(self, *args, **kwargs): + yield from self.embedder.named_parameters(*args, **kwargs) + + def get_clip_image_embeds_from_tensors( + self, tensors_0_1: torch.Tensor, drop=False, + is_training=False, + has_been_preprocessed=False + ) -> torch.Tensor: + with torch.no_grad(): + if not has_been_preprocessed: + # tensors should be 0-1 + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + # training tensors are 0 - 1 + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + + # if images are out of this range throw error + if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: + raise ValueError("image tensor values must be between 0 and 1. 
Got min: {}, max: {}".format( + tensors_0_1.min(), tensors_0_1.max() + )) + # unconditional + if drop: + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 + clip_image = self.clip_image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + else: + if drop: + # scale the noise down + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 + mean = torch.tensor(self.clip_image_processor.image_mean).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + std = torch.tensor(self.clip_image_processor.image_std).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + tensors_0_1 = torch.clip((255. 
* tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + + else: + clip_image = tensors_0_1 + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + with torch.set_grad_enabled(is_training): + if is_training: + self.image_encoder.train() + else: + self.image_encoder.eval() + clip_output = self.image_encoder(clip_image, output_hidden_states=True) + + if self.config.clip_layer == 'penultimate_hidden_states': + # they skip last layer for ip+ + # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26 + clip_image_embeds = clip_output.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + clip_image_embeds = clip_output.hidden_states[-1] + else: + clip_image_embeds = clip_output.image_embeds + return clip_image_embeds + + import torch + + def set_vec(self, new_vector, text_encoder_idx=0): + # Get the embedding layer + embedding_layer = self.text_encoder_list[text_encoder_idx].get_input_embeddings() + + # Indices to replace in the embeddings + indices_to_replace = self.placeholder_token_ids[text_encoder_idx] + + # Replace the specified embeddings with new_vector + for idx in indices_to_replace: + vector_idx = idx - indices_to_replace[0] + embedding_layer.weight[idx] = new_vector[vector_idx] + + # adds it to the tokenizer + def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + if clip_image_embeds.ndim == 2: + # expand the token dimension + clip_image_embeds = clip_image_embeds.unsqueeze(1) + image_prompt_embeds = self.embedder(clip_image_embeds) + # todo add support for multiple batch sizes + if image_prompt_embeds.shape[0] != 1: + raise ValueError("Batch size must be 1 for embedder for now") + + # output on sd1.5 is bs, num_tokens, 768 + if 
len(self.text_encoder_list) == 1: + # add it to the text encoder + self.set_vec(image_prompt_embeds[0], text_encoder_idx=0) + elif len(self.text_encoder_list) == 2: + if self.text_encoder_list[0].config.target_hidden_size + self.text_encoder_list[1].config.target_hidden_size != \ + image_prompt_embeds.shape[2]: + raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes") + # sdxl variants + # image_prompt_embeds = 2048 + # te1 = 768 + # te2 = 1280 + te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.target_hidden_size] + te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.target_hidden_size:] + self.set_vec(te1_embeds[0], text_encoder_idx=0) + self.set_vec(te2_embeds[0], text_encoder_idx=1) + else: + + raise ValueError("Unsupported number of text encoders") + # just a place to put a breakpoint + pass + + def restore_embeddings(self): + # Let's make sure we don't update any embedding weights besides the newly added token + for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip( + self.text_encoder_list, + self.tokenizer_list, + self.orig_embeds_params, + self.placeholder_token_ids + ): + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[ + min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False + with torch.no_grad(): + text_encoder.get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds[index_no_updates] + # detach it all + text_encoder.get_input_embeddings().weight.detach_() + + def enable_gradient_checkpointing(self): + self.image_encoder.gradient_checkpointing = True + + def inject_trigger_into_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True): + output_prompt = prompt + embedding_tokens = self.embedding_tokens[0] # shoudl be the same + default_replacements = ["[name]", "[trigger]"] + + replace_with = embedding_tokens if expand_token else self.trigger + if to_replace_list is 
None: + to_replace_list = default_replacements + else: + to_replace_list += default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + # see how many times replace_with is in the prompt + num_instances = output_prompt.count(replace_with) + + if num_instances == 0 and add_if_not_present: + # add it to the beginning of the prompt + output_prompt = replace_with + " " + output_prompt + + if num_instances > 1: + print( + f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.") + + return output_prompt + + # reverses injection with class name. useful for normalizations + def inject_trigger_class_name_into_prompt(self, prompt): + output_prompt = prompt + embedding_tokens = self.embedding_tokens[0] # shoudl be the same + + default_replacements = ["[name]", "[trigger]", embedding_tokens, self.trigger] + + replace_with = self.config.trigger_class_name + to_replace_list = default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + # see how many times replace_with is in the prompt + num_instances = output_prompt.count(replace_with) + + if num_instances > 1: + print( + f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. 
This may cause issues.") + + return output_prompt diff --git a/ai-toolkit/toolkit/config.py b/ai-toolkit/toolkit/config.py new file mode 100644 index 0000000000000000000000000000000000000000..52de47b836540c319e3ca8faa479312619139769 --- /dev/null +++ b/ai-toolkit/toolkit/config.py @@ -0,0 +1,110 @@ +import os +import json +from typing import Union + +import oyaml as yaml +import re +from collections import OrderedDict + +from toolkit.paths import TOOLKIT_ROOT + +possible_extensions = ['.json', '.jsonc', '.yaml', '.yml'] + + +def get_cwd_abs_path(path): + if not os.path.isabs(path): + path = os.path.join(os.getcwd(), path) + return path + + +def replace_env_vars_in_string(s: str) -> str: + """ + Replace placeholders like ${VAR_NAME} with the value of the corresponding environment variable. + If the environment variable is not set, raise an error. + """ + + def replacer(match): + var_name = match.group(1) + value = os.environ.get(var_name) + + if value is None: + raise ValueError(f"Environment variable {var_name} not set. Please ensure it's defined before proceeding.") + + return value + + return re.sub(r'\$\{([^}]+)\}', replacer, s) + + +def preprocess_config(config: OrderedDict, name: str = None): + if "job" not in config: + raise ValueError("config file must have a job key") + if "config" not in config: + raise ValueError("config file must have a config section") + if "name" not in config["config"] and name is None: + raise ValueError("config file must have a config.name key") + # we need to replace tags. 
For now just [name] + if name is None: + name = config["config"]["name"] + config_string = json.dumps(config) + config_string = config_string.replace("[name]", name) + config = json.loads(config_string, object_pairs_hook=OrderedDict) + return config + + +# Fixes issue where yaml doesnt load exponents correctly +fixed_loader = yaml.SafeLoader +fixed_loader.add_implicit_resolver( + u'tag:yaml.org,2002:float', + re.compile(u'''^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? + |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$''', re.X), + list(u'-+0123456789.')) + + +def get_config( + config_file_path_or_dict: Union[str, dict, OrderedDict], + name=None +): + # if we got a dict, process it and return it + if isinstance(config_file_path_or_dict, dict) or isinstance(config_file_path_or_dict, OrderedDict): + config = config_file_path_or_dict + return preprocess_config(config, name) + + config_file_path = config_file_path_or_dict + + # first check if it is in the config folder + config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path) + # see if it is in the config folder with any of the possible extensions if it doesnt have one + real_config_path = None + if not os.path.exists(config_path): + for ext in possible_extensions: + if os.path.exists(config_path + ext): + real_config_path = config_path + ext + break + + # if we didn't find it there, check if it is a full path + if not real_config_path: + if os.path.exists(config_file_path): + real_config_path = config_file_path + elif os.path.exists(get_cwd_abs_path(config_file_path)): + real_config_path = get_cwd_abs_path(config_file_path) + + if not real_config_path: + raise ValueError(f"Could not find config file {config_file_path}") + + # if we found it, check if it is a json or yaml file + with open(real_config_path, 'r', encoding='utf-8') as f: + content = f.read() + 
content_with_env_replaced = replace_env_vars_in_string(content) + if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'): + config = json.loads(content_with_env_replaced, object_pairs_hook=OrderedDict) + elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'): + config = yaml.load(content_with_env_replaced, Loader=fixed_loader) + else: + raise ValueError(f"Config file {config_file_path} must be a json or yaml file") + + return preprocess_config(config, name) diff --git a/ai-toolkit/toolkit/config_modules.py b/ai-toolkit/toolkit/config_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5698472dcf5d8ea0b6fdb9754bd5394c1fd22a --- /dev/null +++ b/ai-toolkit/toolkit/config_modules.py @@ -0,0 +1,1416 @@ +import os +import time +from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict +import random + +import torch +import torchaudio + +from toolkit.audio.album_artwork import add_album_artwork +from toolkit.prompt_utils import PromptEmbeds +from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH + +ImgExt = Literal['jpg', 'png', 'webp'] + +SaveFormat = Literal['safetensors', 'diffusers'] + +if TYPE_CHECKING: + from toolkit.guidance import GuidanceType + from toolkit.logging_aitk import EmptyLogger +else: + EmptyLogger = None + +class SaveConfig: + def __init__(self, **kwargs): + self.save_every: int = kwargs.get('save_every', 1000) + self.dtype: str = kwargs.get('dtype', 'float16') + self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5) + self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors') + if self.save_format not in ['safetensors', 'diffusers']: + raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}") + self.push_to_hub: bool = kwargs.get("push_to_hub", False) + self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None) + self.hf_private: Optional[str] = kwargs.get("hf_private", 
False) + +class LoggingConfig: + def __init__(self, **kwargs): + self.log_every: int = kwargs.get('log_every', 100) + self.verbose: bool = kwargs.get('verbose', False) + self.use_wandb: bool = kwargs.get('use_wandb', False) + self.use_ui_logger: bool = kwargs.get('use_ui_logger', False) + self.project_name: str = kwargs.get('project_name', 'ai-toolkit') + self.run_name: str = kwargs.get('run_name', None) + +class SampleItem: + def __init__( + self, + sample_config: 'SampleConfig', + **kwargs + ): + # prompt should always be in the kwargs + self.prompt = kwargs.get('prompt', None) + self.width: int = kwargs.get('width', sample_config.width) + self.height: int = kwargs.get('height', sample_config.height) + self.neg: str = kwargs.get('neg', sample_config.neg) + self.seed: Optional[int] = kwargs.get('seed', None) # if none, default to autogen seed + self.guidance_scale: float = kwargs.get('guidance_scale', sample_config.guidance_scale) + self.sample_steps: int = kwargs.get('sample_steps', sample_config.sample_steps) + self.fps: int = kwargs.get('fps', sample_config.fps) + self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames) + self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None) + self.ctrl_idx: int = kwargs.get('ctrl_idx', 0) + # for multi control image models + self.ctrl_img_1: Optional[str] = kwargs.get('ctrl_img_1', self.ctrl_img) + self.ctrl_img_2: Optional[str] = kwargs.get('ctrl_img_2', None) + self.ctrl_img_3: Optional[str] = kwargs.get('ctrl_img_3', None) + + self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier) + # convert to a number if it is a string + if isinstance(self.network_multiplier, str): + try: + self.network_multiplier = float(self.network_multiplier) + except: + print(f"Invalid network_multiplier {self.network_multiplier}, defaulting to 1.0") + self.network_multiplier = 1.0 + + # only for models that support it, (qwen image edit 2509 for now) + self.do_cfg_norm: bool = 
kwargs.get('do_cfg_norm', False) + +class SampleConfig: + def __init__(self, **kwargs): + self.sampler: str = kwargs.get('sampler', 'ddpm') + self.sample_every: int = kwargs.get('sample_every', 100) + self.width: int = kwargs.get('width', 512) + self.height: int = kwargs.get('height', 512) + self.neg = kwargs.get('neg', False) + self.seed = kwargs.get('seed', 0) + self.walk_seed = kwargs.get('walk_seed', False) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + self.network_multiplier = kwargs.get('network_multiplier', 1) + self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.ext: ImgExt = kwargs.get('format', 'jpg') + self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0) + self.refiner_start_at = kwargs.get('refiner_start_at', + 0.5) # step to start using refiner on sample if it exists + self.extra_values = kwargs.get('extra_values', []) + self.num_frames = kwargs.get('num_frames', 1) + self.fps: int = kwargs.get('fps', 16) + if self.num_frames > 1 and self.ext not in ['webp']: + print("Changing sample extention to animated webp") + self.ext = 'webp' + + prompts: list[str] = kwargs.get('prompts', []) + + self.samples: Optional[List[SampleItem]] = None + # use the legacy prompts if it is passed that way to get samples object + default_samples_kwargs = [ + {"prompt": x} for x in prompts + ] + raw_samples = kwargs.get('samples', default_samples_kwargs) + self.samples = [SampleItem(self, **item) for item in raw_samples] + # only for models that support it, (qwen image edit 2509 for now) + self.do_cfg_norm: bool = kwargs.get('do_cfg_norm', False) + + @property + def prompts(self): + # for backwards compatibility as this is checked for length frequently + return [sample.prompt for sample in self.samples if sample.prompt is not None] + + + + +class LormModuleSettingsConfig: + def __init__(self, **kwargs): + self.contains: str = kwargs.get('contains', '4nt$3') + 
self.extract_mode: str = kwargs.get('extract_mode', 'ratio') + # min num parameters to attach to + self.parameter_threshold: int = kwargs.get('parameter_threshold', 0) + self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25) + + +class LoRMConfig: + def __init__(self, **kwargs): + self.extract_mode: str = kwargs.get('extract_mode', 'ratio') + self.do_conv: bool = kwargs.get('do_conv', False) + self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25) + self.parameter_threshold: int = kwargs.get('parameter_threshold', 0) + module_settings = kwargs.get('module_settings', []) + default_module_settings = { + 'extract_mode': self.extract_mode, + 'extract_mode_param': self.extract_mode_param, + 'parameter_threshold': self.parameter_threshold, + } + module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings] + self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for + module_setting in module_settings] + + def get_config_for_module(self, block_name): + for setting in self.module_settings: + contain_pieces = setting.contains.split('|') + if all(contain_piece in block_name for contain_piece in contain_pieces): + return setting + # try replacing the . 
with _ + contain_pieces = setting.contains.replace('.', '_').split('|') + if all(contain_piece in block_name for contain_piece in contain_pieces): + return setting + # do default + return LormModuleSettingsConfig(**{ + 'extract_mode': self.extract_mode, + 'extract_mode_param': self.extract_mode_param, + 'parameter_threshold': self.parameter_threshold, + }) + + +NetworkType = Literal['lora', 'locon', 'lorm', 'lokr'] + + +class NetworkConfig: + def __init__(self, **kwargs): + self.type: NetworkType = kwargs.get('type', 'lora') + rank = kwargs.get('rank', None) + linear = kwargs.get('linear', None) + if rank is not None: + self.rank: int = rank # rank for backward compatibility + self.linear: int = rank + elif linear is not None: + self.rank: int = linear + self.linear: int = linear + else: + self.rank: int = 4 + self.linear: int = 4 + self.conv: int = kwargs.get('conv', None) + self.alpha: float = kwargs.get('alpha', 1.0) + self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) + self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) + self.dropout: Union[float, None] = kwargs.get('dropout', None) + self.network_kwargs: dict = kwargs.get('network_kwargs', {}) + + self.lorm_config: Union[LoRMConfig, None] = None + lorm = kwargs.get('lorm', None) + if lorm is not None: + self.lorm_config: LoRMConfig = LoRMConfig(**lorm) + + if self.type == 'lorm': + # set linear to arbitrary values so it makes them + self.linear = 4 + self.rank = 4 + if self.lorm_config.do_conv: + self.conv = 4 + + self.transformer_only = kwargs.get('transformer_only', True) + + self.lokr_full_rank = kwargs.get('lokr_full_rank', False) + if self.lokr_full_rank and self.type.lower() == 'lokr': + self.linear = 9999999999 + self.linear_alpha = 9999999999 + self.conv = 9999999999 + self.conv_alpha = 9999999999 + # -1 automatically finds the largest factor + self.lokr_factor = kwargs.get('lokr_factor', -1) + + # Use the old lokr format + self.old_lokr_format = kwargs.get('old_lokr_format', 
False) + + # for multi stage models + self.split_multistage_loras = kwargs.get('split_multistage_loras', True) + + # ramtorch, doesn't work yet + self.layer_offloading = kwargs.get('layer_offloading', False) + + # start from a pretrained lora + self.pretrained_lora_path = kwargs.get('pretrained_lora_path', None) + + # will create diffirential full weight modules for layers not conv/linear + # only useful in very special cases. + self.all_layers = kwargs.get('all_layers', False) + + +AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v'] + +CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state'] + + +class AdapterConfig: + def __init__(self, **kwargs): + self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net, i2v + self.in_channels: int = kwargs.get('in_channels', 3) + self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280]) + self.num_res_blocks: int = kwargs.get('num_res_blocks', 2) + self.downscale_factor: int = kwargs.get('downscale_factor', 8) + self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter') + self.image_dir: str = kwargs.get('image_dir', None) + self.test_img_path: List[str] = kwargs.get('test_img_path', None) + if self.test_img_path is not None: + if isinstance(self.test_img_path, str): + self.test_img_path = self.test_img_path.split(',') + self.test_img_path = [p.strip() for p in self.test_img_path] + self.test_img_path = [p for p in self.test_img_path if p != ''] + + self.train: str = kwargs.get('train', False) + self.image_encoder_path: str = kwargs.get('image_encoder_path', None) + self.name_or_path = kwargs.get('name_or_path', None) + + num_tokens = kwargs.get('num_tokens', None) + if num_tokens is None and self.type.startswith('ip'): + if self.type == 'ip+': + num_tokens = 16 + num_tokens = 16 + elif self.type == 'ip': + num_tokens = 4 + + self.num_tokens: int = num_tokens + self.train_image_encoder: 
bool = kwargs.get('train_image_encoder', False) + self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False) + if self.train_only_image_encoder: + self.train_image_encoder = True + self.train_only_image_encoder_positional_embedding: bool = kwargs.get( + 'train_only_image_encoder_positional_embedding', False) + self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe + self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512) + self.safe_channels: int = kwargs.get('safe_channels', 2048) + self.safe_tokens: int = kwargs.get('safe_tokens', 8) + self.quad_image: bool = kwargs.get('quad_image', False) + + # clip vision + self.trigger = kwargs.get('trigger', 'tri993r') + self.trigger_class_name = kwargs.get('trigger_class_name', None) + + self.class_names = kwargs.get('class_names', []) + + self.clip_layer: CLIPLayer = kwargs.get('clip_layer', None) + if self.clip_layer is None: + if self.type.startswith('ip+'): + self.clip_layer = 'penultimate_hidden_states' + else: + self.clip_layer = 'last_hidden_state' + + # text encoder + self.text_encoder_path: str = kwargs.get('text_encoder_path', None) + self.text_encoder_arch: str = kwargs.get('text_encoder_arch', 'clip') # clip t5 + + self.train_scaler: bool = kwargs.get('train_scaler', False) + self.scaler_lr: Optional[float] = kwargs.get('scaler_lr', None) + + # trains with a scaler to easy channel bias but merges it in on save + self.merge_scaler: bool = kwargs.get('merge_scaler', False) + + # for ilora + self.head_dim: int = kwargs.get('head_dim', 1024) + self.num_heads: int = kwargs.get('num_heads', 1) + self.ilora_down: bool = kwargs.get('ilora_down', True) + self.ilora_mid: bool = kwargs.get('ilora_mid', True) + self.ilora_up: bool = kwargs.get('ilora_up', True) + + self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512) + self.pixtral_random_image_size: int = kwargs.get('pixtral_random_image_size', False) + + 
self.flux_only_double: bool = kwargs.get('flux_only_double', False) + + # train and use a conv layer to pool the embedding + self.conv_pooling: bool = kwargs.get('conv_pooling', False) + self.conv_pooling_stacks: int = kwargs.get('conv_pooling_stacks', 1) + self.sparse_autoencoder_dim: Optional[int] = kwargs.get('sparse_autoencoder_dim', None) + + # for llm adapter + self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0) + self.quantize_llm: bool = kwargs.get('quantize_llm', False) + + # for control lora only + lora_config: dict = kwargs.get('lora_config', None) + if lora_config is not None: + self.lora_config: NetworkConfig = NetworkConfig(**lora_config) + else: + self.lora_config = None + self.num_control_images: int = kwargs.get('num_control_images', 1) + # decimal for how often the control is dropped out and replaced with noise 1.0 is 100% + self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0) + self.has_inpainting_input: bool = kwargs.get('has_inpainting_input', False) + self.invert_inpaint_mask_chance: float = kwargs.get('invert_inpaint_mask_chance', 0.0) + + # for subpixel adapter + self.subpixel_downscale_factor: int = kwargs.get('subpixel_downscale_factor', 8) + + # for i2v adapter + # append the masked start frame. 
During pretraining we will only do the vision encoder + self.i2v_do_start_frame: bool = kwargs.get('i2v_do_start_frame', False) + + +class EmbeddingConfig: + def __init__(self, **kwargs): + self.trigger = kwargs.get('trigger', 'custom_embedding') + self.tokens = kwargs.get('tokens', 4) + self.init_words = kwargs.get('init_words', '*') + self.save_format = kwargs.get('save_format', 'safetensors') + self.trigger_class_name = kwargs.get('trigger_class_name', None) # used for inverted masked prior + + +class DecoratorConfig: + def __init__(self, **kwargs): + self.num_tokens: str = kwargs.get('num_tokens', 4) + + +ContentOrStyleType = Literal['balanced', 'style', 'content'] +LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise'] + + +class TrainConfig: + def __init__(self, **kwargs): + self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') + self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') + self.content_or_style_reg: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') + self.steps: int = kwargs.get('steps', 1000) + self.lr = kwargs.get('lr', 1e-6) + self.unet_lr = kwargs.get('unet_lr', self.lr) + self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr) + self.refiner_lr = kwargs.get('refiner_lr', self.lr) + self.embedding_lr = kwargs.get('embedding_lr', self.lr) + self.adapter_lr = kwargs.get('adapter_lr', self.lr) + self.optimizer = kwargs.get('optimizer', 'adamw') + self.optimizer_params = kwargs.get('optimizer_params', {}) + self.lr_scheduler = kwargs.get('lr_scheduler', 'constant') + self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {}) + self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0) + self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 999) + self.batch_size: int = kwargs.get('batch_size', 1) + self.orig_batch_size: int = self.batch_size + self.dtype: str = kwargs.get('dtype', 'fp32') + self.xformers = kwargs.get('xformers', False) + 
self.sdp = kwargs.get('sdp', False) + # see https://huggingface.co/docs/diffusers/main/optimization/attention_backends#available-backends for options + self.attention_backend: str = kwargs.get('attention_backend', 'native') # native, flash, _flash_3_hub, _flash_3, + self.train_unet = kwargs.get('train_unet', True) + self.train_text_encoder = kwargs.get('train_text_encoder', False) + self.train_refiner = kwargs.get('train_refiner', True) + self.train_turbo = kwargs.get('train_turbo', False) + self.show_turbo_outputs = kwargs.get('show_turbo_outputs', False) + self.min_snr_gamma = kwargs.get('min_snr_gamma', None) + self.snr_gamma = kwargs.get('snr_gamma', None) + # trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials + # this should balance the learning rate across all timesteps over time + self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False) + self.noise_offset = kwargs.get('noise_offset', 0.0) + self.skip_first_sample = kwargs.get('skip_first_sample', False) + self.force_first_sample = kwargs.get('force_first_sample', False) + self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True) + self.weight_jitter = kwargs.get('weight_jitter', 0.0) + self.merge_network_on_save = kwargs.get('merge_network_on_save', False) + self.merge_network_on_save_strength = kwargs.get('merge_network_on_save_strength', 1.0) + self.max_grad_norm = kwargs.get('max_grad_norm', 1.0) + self.start_step = kwargs.get('start_step', None) + self.free_u = kwargs.get('free_u', False) + self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None) + self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net + self.noise_multiplier = kwargs.get('noise_multiplier', 1.0) + self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0) + self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0) + self.do_signal_correction_noise = 
kwargs.get('do_signal_correction_noise', False) + # batch noise correction adds other images in the batch as noise to correct away from other images + self.do_batch_noise_correction = kwargs.get('do_batch_noise_correction', False) + self.batch_noise_correction_scale = kwargs.get('batch_noise_correction_scale', 0.1) + self.do_signal_amplification = kwargs.get('do_signal_amplification', False) + self.signal_amplification_strength = kwargs.get('signal_amplification_strength', 0.5) + + self.signal_correction_noise_scale = kwargs.get('signal_correction_noise_scale', 1.0) + self.random_noise_shift = kwargs.get('random_noise_shift', 0.0) + self.img_multiplier = kwargs.get('img_multiplier', 1.0) + self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0) + self.latent_multiplier = kwargs.get('latent_multiplier', 1.0) + self.negative_prompt = kwargs.get('negative_prompt', None) + self.max_negative_prompts = kwargs.get('max_negative_prompts', 1) + # multiplier applied to loos on regularization images + self.reg_weight = kwargs.get('reg_weight', 1.0) + self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000) + # automatically adapte the vae scaling based on the image norm + self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False) + + # dropout that happens before encoding. It functions independently per text encoder + self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0) + + # match the norm of the noise before computing loss. This will help the model maintain its + # current understandin of the brightness of images. 
+ + self.match_noise_norm = kwargs.get('match_noise_norm', False) + + # set to -1 to accumulate gradients for entire epoch + # warning, only do this with a small dataset or you will run out of memory + # This is legacy but left in for backwards compatibility + self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1) + + # this will do proper gradient accumulation where you will not see a step until the end of the accumulation + # the method above will show a step every accumulation + self.gradient_accumulation = kwargs.get('gradient_accumulation', 1) + if self.gradient_accumulation > 1: + if self.gradient_accumulation_steps != 1: + raise ValueError("gradient_accumulation and gradient_accumulation_steps are mutually exclusive") + + # short long captions will double your batch size. This only works when a dataset is + # prepared with a json caption file that has both short and long captions in it. It will + # Double up every image and run it through with both short and long captions. The idea + # is that the network will learn how to generate good images with both short and long captions + self.short_and_long_captions = kwargs.get('short_and_long_captions', False) + # if above is NOT true, this will make it so the long caption foes to te2 and the short caption goes to te1 for sdxl only + self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False) + + # basically gradient accumulation but we run just 1 item through the network + # and accumulate gradients. 
This can be used as basic gradient accumulation but is very helpful + # for training tricks that increase batch size but need a single gradient step + self.single_item_batching = kwargs.get('single_item_batching', False) + + match_adapter_assist = kwargs.get('match_adapter_assist', False) + self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0) + self.loss_target: LossTarget = kwargs.get('loss_target', + 'noise') # noise, source, unaugmented, differential_noise + + # When a mask is passed in a dataset, and this is true, + # we will predict noise without a the LoRa network and use the prediction as a target for + # unmasked reign. It is unmasked regularization basically + self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False) + self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5) + + # DOP will will run the same image and prompt through the network without the trigger word blank and use it as a target + self.diff_output_preservation = kwargs.get('diff_output_preservation', False) + self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0) + # If the trigger word is in the prompt, we will use this class name to replace it eg. 
"sks woman" -> "woman" + self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '') + + # blank prompt preservation will preserve the model's knowledge of a blank prompt + self.blank_prompt_preservation = kwargs.get('blank_prompt_preservation', False) + self.blank_prompt_preservation_multiplier = kwargs.get('blank_prompt_preservation_multiplier', 1.0) + + # legacy + if match_adapter_assist and self.match_adapter_chance == 0.0: + self.match_adapter_chance = 1.0 + + # standardize inputs to the meand std of the model knowledge + self.standardize_images = kwargs.get('standardize_images', False) + self.standardize_latents = kwargs.get('standardize_latents', False) + + # if self.train_turbo and not self.noise_scheduler.startswith("euler"): + # raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers") + + self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False) + self.do_cfg = kwargs.get('do_cfg', False) + self.do_random_cfg = kwargs.get('do_random_cfg', False) + self.cfg_scale = kwargs.get('cfg_scale', 1.0) + self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale) + self.cfg_rescale = kwargs.get('cfg_rescale', None) + if self.cfg_rescale is None: + self.cfg_rescale = self.cfg_scale + + # applies the inverse of the prediction mean and std to the target to correct + # for norm drift + self.correct_pred_norm = kwargs.get('correct_pred_norm', False) + self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0) + + self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow, pseudo_huber + + # do the loss on a timestep to 0 prediction + self.t0_loss_target = kwargs.get('t0_loss_target', False) + self.t0_velocity_equiv_weight = kwargs.get('t0_velocity_equiv_weight', False) + + # do additional fft loss + self.do_fft_loss = kwargs.get('do_fft_loss', False) + self.do_fft_velocity_equiv_weight = kwargs.get('do_fft_velocity_equiv_weight', False) 
+ + # scale the prediction by this. Increase for more detail, decrease for less + self.pred_scaler = kwargs.get('pred_scaler', 1.0) + + # repeats the prompt a few times to saturate the encoder + self.prompt_saturation_chance = kwargs.get('prompt_saturation_chance', 0.0) + + # applies negative loss on the prior to encourage network to diverge from it + self.do_prior_divergence = kwargs.get('do_prior_divergence', False) + + ema_config: Union[Dict, None] = kwargs.get('ema_config', None) + # if it is set explicitly to false, leave it false. + if ema_config is not None and ema_config.get('use_ema', False): + ema_config['use_ema'] = True + print(f"Using EMA") + else: + ema_config = {'use_ema': False} + + self.ema_config: EMAConfig = EMAConfig(**ema_config) + + # adds an additional loss to the network to encourage it output a normalized standard deviation + self.target_norm_std = kwargs.get('target_norm_std', None) + self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0) + self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample, weighted, one_step + self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8) + self.linear_timesteps = kwargs.get('linear_timesteps', False) + self.linear_timesteps2 = kwargs.get('linear_timesteps2', False) + self.disable_sampling = kwargs.get('disable_sampling', False) + + # will cache a blank prompt or the trigger word, and unload the text encoder to cpu + # will make training faster and use less vram + self.unload_text_encoder = kwargs.get('unload_text_encoder', False) + # will toggle all datasets to cache text embeddings + self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False) + # for swapping which parameters are trained during training + self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False) + # 0.1 is 10% of the parameters active at a time lower is less vram, higher is more + self.paramiter_swapping_factor = 
kwargs.get('paramiter_swapping_factor', 0.1) + # bypass the guidance embedding for training. For open flux with guidance embedding + self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False) + + # diffusion feature extractor + self.latent_feature_extractor_path = kwargs.get('latent_feature_extractor_path', None) + self.latent_feature_loss_weight = kwargs.get('latent_feature_loss_weight', 1.0) + + # we use this in the code, but it really needs to be called latent_feature_extractor as that makes more sense with new architecture + self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', self.latent_feature_extractor_path) + self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', self.latent_feature_loss_weight) + + # optimal noise pairing + self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1) + + # forces same noise for the same image at a given size. + self.force_consistent_noise = kwargs.get('force_consistent_noise', False) + self.blended_blur_noise = kwargs.get('blended_blur_noise', False) + + # contrastive loss + self.do_guidance_loss = kwargs.get('do_guidance_loss', False) + self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0) + self.do_guidance_loss_cfg_zero: bool = kwargs.get('do_guidance_loss_cfg_zero', False) + self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '') + if isinstance(self.guidance_loss_target, tuple): + self.guidance_loss_target = list(self.guidance_loss_target) + + self.do_differential_guidance = kwargs.get('do_differential_guidance', False) + self.differential_guidance_scale = kwargs.get('differential_guidance_scale', 3.0) + + # for multi stage models, how often to switch the boundary + self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1) + + # stabilizes empty prompts to be zeroed predictions + self.do_blank_stabilization = 
kwargs.get('do_blank_stabilization', False) + + self.audio_loss_multiplier = kwargs.get("audio_loss_multiplier", 1.0) + + # will throw detailed error when it goes over + self.max_loss_debug: bool = kwargs.get("max_loss_debug", False) + # will clip the loss to this amount to prevent wild outliers + self.max_loss: Optional[float] = kwargs.get("max_loss", None) + + +ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21'] + + +class ModelConfig: + def __init__(self, **kwargs): + self.name_or_path: str = kwargs.get('name_or_path', None) + # name or path is updated on fine tuning. Keep a copy of the original + self.name_or_path_original: str = self.name_or_path + self.is_v2: bool = kwargs.get('is_v2', False) + self.is_xl: bool = kwargs.get('is_xl', False) + self.is_pixart: bool = kwargs.get('is_pixart', False) + self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False) + self.is_auraflow: bool = kwargs.get('is_auraflow', False) + self.is_v3: bool = kwargs.get('is_v3', False) + self.is_flux: bool = kwargs.get('is_flux', False) + self.is_lumina2: bool = kwargs.get('is_lumina2', False) + if self.is_pixart_sigma: + self.is_pixart = True + self.use_flux_cfg = kwargs.get('use_flux_cfg', False) + self.is_ssd: bool = kwargs.get('is_ssd', False) + self.is_vega: bool = kwargs.get('is_vega', False) + self.is_v_pred: bool = kwargs.get('is_v_pred', False) + self.dtype: str = kwargs.get('dtype', 'float16') + self.vae_path = kwargs.get('vae_path', None) + self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None) + self._original_refiner_name_or_path = self.refiner_name_or_path + self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) + self.lora_path = kwargs.get('lora_path', None) + # mainly for decompression loras for distilled models + self.assistant_lora_path = kwargs.get('assistant_lora_path', None) + self.inference_lora_path = kwargs.get('inference_lora_path', None) 
+ # a lora that stays inactive except during the unconditional (negative) + # CFG pass -- used to learn the unconditional branch without a second model + self.unconditional_lora_path = kwargs.get('unconditional_lora_path', None) + self.latent_space_version = kwargs.get('latent_space_version', None) + + # only for SDXL models for now + self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True) + self.use_text_encoder_2: bool = kwargs.get('use_text_encoder_2', True) + + self.experimental_xl: bool = kwargs.get('experimental_xl', False) + + if self.name_or_path is None: + raise ValueError('name_or_path must be specified') + + if self.is_ssd: + # sed sdxl as true since it is mostly the same architecture + self.is_xl = True + + if self.is_vega: + self.is_xl = True + + # for text encoder quant. Only works with pixart currently + self.text_encoder_bits = kwargs.get('text_encoder_bits', 16) # 16, 8, 4 + self.unet_path = kwargs.get("unet_path", None) + self.unet_sample_size = kwargs.get("unet_sample_size", None) + self.vae_device = kwargs.get("vae_device", None) + self.vae_dtype = kwargs.get("vae_dtype", self.dtype) + self.te_device = kwargs.get("te_device", None) + self.te_dtype = kwargs.get("te_dtype", self.dtype) + + # only for flux for now + self.quantize = kwargs.get("quantize", False) + self.quantize_te = kwargs.get("quantize_te", self.quantize) + self.qtype = kwargs.get("qtype", "qfloat8") + self.qtype_te = kwargs.get("qtype_te", "qfloat8") + self.low_vram = kwargs.get("low_vram", False) + self.attn_masking = kwargs.get("attn_masking", False) + if self.attn_masking and not self.is_flux: + raise ValueError("attn_masking is only supported with flux models currently") + # for targeting a specific layers + self.ignore_if_contains: Optional[List[str]] = kwargs.get("ignore_if_contains", None) + self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None) + self.quantize_kwargs = kwargs.get("quantize_kwargs", {}) + + # splits the model over 
the available gpus WIP + self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False) + if self.split_model_over_gpus and not self.is_flux: + raise ValueError("split_model_over_gpus is only supported with flux models currently") + self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3) + + self.te_name_or_path = kwargs.get("te_name_or_path", None) + + self.arch: ModelArch = kwargs.get("arch", None) + + # auto memory management, only for some models + self.auto_memory = kwargs.get("auto_memory", False) + # auto memory is deprecated, use layer offloading instead + if self.auto_memory: + print("auto_memory is deprecated, use layer_offloading instead") + self.layer_offloading = kwargs.get("layer_offloading", self.auto_memory ) + if self.layer_offloading and self.qtype == "qfloat8": + self.qtype = "float8" + if self.layer_offloading and self.qtype_te == "qfloat8": + self.qtype_te = "float8" + + # Mac mps only works with torachao uint + if torch.backends.mps.is_available() and self.qtype == "qfloat8": + self.qtype = "int8" + if torch.backends.mps.is_available() and self.qtype_te == "qfloat8": + self.qtype_te = "int8" + + # 0 is off and 1.0 is 100% of the layers + self.layer_offloading_transformer_percent = kwargs.get("layer_offloading_transformer_percent", 1.0) + self.layer_offloading_text_encoder_percent = kwargs.get("layer_offloading_text_encoder_percent", 1.0) + + # can be used to load the extras like text encoder or vae from here + # only setup for some models but will prevent having to download the te for + # 20 different model variants + self.extras_name_or_path = kwargs.get("extras_name_or_path", self.name_or_path) + + # path to an accuracy recovery adapter, either local or remote + self.accuracy_recovery_adapter = kwargs.get("accuracy_recovery_adapter", None) + + # parse ARA from qtype + if self.qtype is not None and "|" in self.qtype: + self.qtype, self.accuracy_recovery_adapter = 
self.qtype.split('|') + + # compile the model with torch compile + self.compile = kwargs.get("compile", False) + + if self.compile and self.quantize: + print("Quantized model detected - allowing torch.compile (experimental)") + self.block_compile = kwargs.get("block_compile", False) + self.compile_mode = kwargs.get("compile_mode", "default") + self.compile_fullgraph = kwargs.get("compile_fullgraph", False) + self.compile_dynamic = kwargs.get("compile_dynamic", True) + self.cache_size_limit = kwargs.get("cache_size_limit", None) + + # kwargs to pass to the model + self.model_kwargs = kwargs.get("model_kwargs", {}) + + # model paths for models that support it + self.model_paths = kwargs.get("model_paths", {}) + + self.in_context = kwargs.get("in_context", False) + + # allow frontend to pass arch with a color like arch:tag + # but remove the tag + if self.arch is not None: + if ':' in self.arch: + self.arch = self.arch.split(':')[0] + + if self.arch == "flex1": + self.arch = "flux" + + + # handle migrating to new model arch + if self.arch is not None: + # reverse the arch to the old style + if self.arch == 'sd2': + self.is_v2 = True + elif self.arch == 'sd3': + self.is_v3 = True + elif self.arch == 'sdxl': + self.is_xl = True + elif self.arch == 'pixart': + self.is_pixart = True + elif self.arch == 'pixart_sigma': + self.is_pixart_sigma = True + elif self.arch == 'auraflow': + self.is_auraflow = True + elif self.arch == 'flux': + self.is_flux = True + elif self.arch == 'lumina2': + self.is_lumina2 = True + elif self.arch == 'vega': + self.is_vega = True + elif self.arch == 'ssd': + self.is_ssd = True + else: + pass + if self.arch is None: + if kwargs.get('is_v2', False): + self.arch = 'sd2' + elif kwargs.get('is_v3', False): + self.arch = 'sd3' + elif kwargs.get('is_xl', False): + self.arch = 'sdxl' + elif kwargs.get('is_pixart', False): + self.arch = 'pixart' + elif kwargs.get('is_pixart_sigma', False): + self.arch = 'pixart_sigma' + elif kwargs.get('is_auraflow', 
False): + self.arch = 'auraflow' + elif kwargs.get('is_flux', False): + self.arch = 'flux' + elif kwargs.get('is_lumina2', False): + self.arch = 'lumina2' + elif kwargs.get('is_vega', False): + self.arch = 'vega' + elif kwargs.get('is_ssd', False): + self.arch = 'ssd' + else: + self.arch = 'sd1' + + + +class EMAConfig: + def __init__(self, **kwargs): + self.use_ema: bool = kwargs.get('use_ema', False) + self.ema_decay: float = kwargs.get('ema_decay', 0.999) + # feeds back the decay difference into the parameter + self.use_feedback: bool = kwargs.get('use_feedback', False) + + # every update, the params are multiplied by this amount + # only use for things without a bias like lora + # similar to a decay in an optimizer but the opposite + self.param_multiplier: float = kwargs.get('param_multiplier', 1.0) + + +class ReferenceDatasetConfig: + def __init__(self, **kwargs): + # can pass with a side by side pait or a folder with pos and neg folder + self.pair_folder: str = kwargs.get('pair_folder', None) + self.pos_folder: str = kwargs.get('pos_folder', None) + self.neg_folder: str = kwargs.get('neg_folder', None) + + self.network_weight: float = float(kwargs.get('network_weight', 1.0)) + self.pos_weight: float = float(kwargs.get('pos_weight', self.network_weight)) + self.neg_weight: float = float(kwargs.get('neg_weight', self.network_weight)) + # make sure they are all absolute values no negatives + self.pos_weight = abs(self.pos_weight) + self.neg_weight = abs(self.neg_weight) + + self.target_class: str = kwargs.get('target_class', '') + self.size: int = kwargs.get('size', 512) + + +class SliderTargetConfig: + def __init__(self, **kwargs): + self.target_class: str = kwargs.get('target_class', '') + self.positive: str = kwargs.get('positive', '') + self.negative: str = kwargs.get('negative', '') + self.multiplier: float = kwargs.get('multiplier', 1.0) + self.weight: float = kwargs.get('weight', 1.0) + self.shuffle: bool = kwargs.get('shuffle', False) + + +class 
GuidanceConfig: + def __init__(self, **kwargs): + self.target_class: str = kwargs.get('target_class', '') + self.guidance_scale: float = kwargs.get('guidance_scale', 1.0) + self.positive_prompt: str = kwargs.get('positive_prompt', '') + self.negative_prompt: str = kwargs.get('negative_prompt', '') + + +class SliderConfigAnchors: + def __init__(self, **kwargs): + self.prompt = kwargs.get('prompt', '') + self.neg_prompt = kwargs.get('neg_prompt', '') + self.multiplier = kwargs.get('multiplier', 1.0) + + +class SliderConfig: + def __init__(self, **kwargs): + targets = kwargs.get('targets', []) + anchors = kwargs.get('anchors', []) + anchors = [SliderConfigAnchors(**anchor) for anchor in anchors] + self.anchors: List[SliderConfigAnchors] = anchors + self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]]) + self.prompt_file: str = kwargs.get('prompt_file', None) + self.prompt_tensors: str = kwargs.get('prompt_tensors', None) + self.batch_full_slide: bool = kwargs.get('batch_full_slide', True) + self.use_adapter: bool = kwargs.get('use_adapter', None) # depth + self.adapter_img_dir = kwargs.get('adapter_img_dir', None) + self.low_ram = kwargs.get('low_ram', False) + + # expand targets if shuffling + from toolkit.prompt_utils import get_slider_target_permutations + self.targets: List[SliderTargetConfig] = [] + targets = [SliderTargetConfig(**target) for target in targets] + # do permutations if shuffle is true + print(f"Building slider targets") + for target in targets: + if target.shuffle: + target_permutations = get_slider_target_permutations(target, max_permutations=8) + self.targets = self.targets + target_permutations + else: + self.targets.append(target) + print(f"Built {len(self.targets)} slider targets (with permutations)") + +ControlTypes = Literal['depth', 'line', 'pose', 'inpaint', 'mask', 'sapiens2_mask'] + +class DatasetConfig: + """ + Dataset config for sd-datasets + + """ + + def __init__(self, **kwargs): + self.type = kwargs.get('type', 
'image') # sd, slider, reference + # will be legacy + self.folder_path: str = kwargs.get('folder_path', None) + # can be json or folder path + self.dataset_path: str = kwargs.get('dataset_path', None) + + self.default_caption: str = kwargs.get('default_caption', None) + # trigger word for just this dataset + self.trigger_word: str = kwargs.get('trigger_word', None) + random_triggers = kwargs.get('random_triggers', []) + # if they are a string, load them from a file + if isinstance(random_triggers, str) and os.path.exists(random_triggers): + with open(random_triggers, 'r') as f: + random_triggers = f.read().splitlines() + # remove empty lines + random_triggers = [line for line in random_triggers if line.strip() != ''] + self.random_triggers: List[str] = random_triggers + self.random_triggers_max: int = kwargs.get('random_triggers_max', 1) + self.caption_ext: str = kwargs.get('caption_ext', '.txt') + # if caption_ext doesnt start with a dot, add it + if self.caption_ext and not self.caption_ext.startswith('.'): + self.caption_ext = '.' 
+ self.caption_ext + self.random_scale: bool = kwargs.get('random_scale', False) + self.random_crop: bool = kwargs.get('random_crop', False) + self.resolution: int = kwargs.get('resolution', 512) + self.scale: float = kwargs.get('scale', 1.0) + self.buckets: bool = kwargs.get('buckets', True) + self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64) + self.is_reg: bool = kwargs.get('is_reg', False) + self.prior_reg: bool = kwargs.get('prior_reg', False) + self.network_weight: float = float(kwargs.get('network_weight', 1.0)) + self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0)) + self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False) + self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0)) + self.keep_tokens: int = kwargs.get('keep_tokens', 0) # #of first tokens to always keep unless caption dropped + self.flip_x: bool = kwargs.get('flip_x', False) + self.flip_y: bool = kwargs.get('flip_y', False) + self.augments: List[str] = kwargs.get('augments', []) + self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc + if self.control_path == '': + self.control_path = None + + # handle multi control inputs from the ui. It is just easier to handle it here for a cleaner ui experience + control_path_1 = kwargs.get('control_path_1', None) + control_path_2 = kwargs.get('control_path_2', None) + control_path_3 = kwargs.get('control_path_3', None) + + if any([control_path_1, control_path_2, control_path_3]): + control_paths = [] + if control_path_1: + control_paths.append(control_path_1) + if control_path_2: + control_paths.append(control_path_2) + if control_path_3: + control_paths.append(control_path_3) + self.control_path = control_paths + + # color for transparent reigon of control images with transparency + self.control_transparent_color: List[int] = kwargs.get('control_transparent_color', [0, 0, 0]) + # inpaint images should be webp/png images with alpha channel. 
The alpha 0 (invisible) section will + # be the part conditioned to be inpainted. The alpha 1 (visible) section will be the part that is ignored + self.inpaint_path: Union[str,List[str]] = kwargs.get('inpaint_path', None) + # instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters) + self.full_size_control_images: bool = kwargs.get('full_size_control_images', True) + self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask + self.mask_path: str = kwargs.get('mask_path', + None) # focus mask (black and white. White has higher loss than black) + self.unconditional_path: str = kwargs.get('unconditional_path', + None) # path where matching unconditional images are located + self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask + self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for . 0 - 1 + self.poi: Union[str, None] = kwargs.get('poi', None) + if self.poi is not None: + raise ValueError("poi is deprecated and is no longer supported") + self.use_short_captions: bool = kwargs.get('use_short_captions', False) # if true, will use 'caption_short' from json + self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset + # cache latents will store them in memory + self.cache_latents: bool = kwargs.get('cache_latents', False) + # cache latents to disk will store them on disk. 
If both are true, it will save to disk, but keep in memory + self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False) + self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False) + self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False) + self.load_image_when_caching_latents: bool = kwargs.get('load_image_when_caching_latents', False) + + self.standardize_images: bool = kwargs.get('standardize_images', False) + + # https://albumentations.ai/docs/api_reference/augmentations/transforms + # augmentations are returned as a separate image and cannot currently be cached + self.augmentations: List[dict] = kwargs.get('augmentations', None) + self.shuffle_augmentations: bool = kwargs.get('shuffle_augmentations', False) + + has_augmentations = self.augmentations is not None and len(self.augmentations) > 0 + + if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk): + print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False") + self.cache_latents = False + self.cache_latents_to_disk = False + + # legacy compatability + legacy_caption_type = kwargs.get('caption_type', None) + if legacy_caption_type: + self.caption_ext = legacy_caption_type + self.caption_type = self.caption_ext + self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted') + + # ip adapter / reference dataset + self.clip_image_path: str = kwargs.get('clip_image_path', None) # depth maps, etc + # get the clip image randomly from the same folder as the image. Useful for folder grouped pairs. 
+ self.clip_image_from_same_folder: bool = kwargs.get('clip_image_from_same_folder', False) + self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None) + self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False) + self.replacements: List[str] = kwargs.get('replacements', []) + self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0) + + self.num_workers: int = kwargs.get('num_workers', 2) + self.prefetch_factor: int = kwargs.get('prefetch_factor', 2) + self.extra_values: List[float] = kwargs.get('extra_values', []) + self.square_crop: bool = kwargs.get('square_crop', False) + # apply same augmentations to control images. Usually want this true unless special case + self.replay_transforms: bool = kwargs.get('replay_transforms', True) + + # for video + # if num_frames is greater than 1, the dataloader will look for video files. + # num_frames will be the number of frames in the training batch. If num_frames is 1, it will look for images + self.num_frames: int = kwargs.get('num_frames', 1) + # if true, will shrink video to our frames. For instance, if we have a video with 100 frames and num_frames is 10, + # we would pull frame 0, 10, 20, 30, 40, 50, 60, 70, 80, 90 so they are evenly spaced + self.shrink_video_to_frames: bool = kwargs.get('shrink_video_to_frames', True) + # fps is only used if shrink_video_to_frames is false. This will attempt to pull the num_frames at the given fps + # it will select a random start frame and pull the frames at the given fps + # this could have various issues with shorter videos and videos with variable fps + # I recommend trimming your videos to the desired length and using shrink_video_to_frames(default) + self.fps: int = kwargs.get('fps', 24) + + # auto_frame_count pull as many frames as in the video at given fps + # Important, make sure fps for dataset is set correctly. 
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
    """
    Expand dataset entries that list several resolutions into one entry per resolution.

    :param raw_config: raw dataset dicts; ``resolution`` may be a single value or a
        list of values, and is treated as 512 when the key is missing
    :return: a new list of shallow copies, each carrying exactly one ``resolution``,
        in input order (an empty resolution list contributes no entries)
    """

    def _one_per_resolution(entry: dict):
        requested = entry.get('resolution', 512)
        sizes = requested if isinstance(requested, list) else [requested]
        for size in sizes:
            # shallow copy: nested values stay shared with the source entry
            clone = entry.copy()
            clone['resolution'] = size
            yield clone

    return [clone for entry in raw_config for clone in _one_per_resolution(entry)]
num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str = '', + negative_prompt_2: Optional[str] = None, + seed: int = -1, + network_multiplier: float = 1.0, + guidance_rescale: float = 0.0, + # the tag [time] will be replaced with milliseconds since epoch + output_path: str = None, # full image path + output_folder: str = None, # folder to save image in if output_path is not specified + output_ext: str = ImgExt, # extension to save image as if output_path is not specified + output_tail: str = '', # tail to add to output filename + add_prompt_file: bool = False, # add a prompt file with generated image + adapter_image_path: str = None, # path to adapter image + adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning + latents: Union[torch.Tensor | None] = None, # input latent to start with, + extra_kwargs: dict = None, # extra data to save with prompt file + refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end + extra_values: List[float] = None, # extra values to save with prompt file + logger: Optional[EmptyLogger] = None, + ctrl_img: Optional[str] = None, # control image for controlnet + ctrl_img_1: Optional[str] = None, # first control image for multi control model + ctrl_img_2: Optional[str] = None, # second control image for multi control model + ctrl_img_3: Optional[str] = None, # third control image for multi control model + num_frames: int = 1, + fps: int = 15, + ctrl_idx: int = 0, + do_cfg_norm: bool = False, + ): + self.width: int = width + self.height: int = height + self.num_inference_steps: int = num_inference_steps + self.guidance_scale: float = guidance_scale + self.guidance_rescale: float = guidance_rescale + self.prompt: str = prompt + self.prompt_2: str = prompt_2 + self.negative_prompt: str = negative_prompt + self.negative_prompt_2: str = negative_prompt_2 + self.latents: Union[torch.Tensor | None] = latents + + self.output_path: str = output_path + 
self.seed: int = seed + if self.seed == -1: + # generate random one + self.seed = random.randint(0, 2 ** 32 - 1) + self.network_multiplier: float = network_multiplier + self.output_folder: str = output_folder + self.output_ext: str = output_ext + self.add_prompt_file: bool = add_prompt_file + self.output_tail: str = output_tail + self.gen_time: int = int(time.time() * 1000) + self.adapter_image_path: str = adapter_image_path + self.adapter_conditioning_scale: float = adapter_conditioning_scale + self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {} + self.refiner_start_at = refiner_start_at + self.extra_values = extra_values if extra_values is not None else [] + self.num_frames = num_frames + self.fps = fps + self.ctrl_img = ctrl_img + self.ctrl_idx = ctrl_idx + + if ctrl_img_1 is None and ctrl_img is not None: + ctrl_img_1 = ctrl_img + + self.ctrl_img_1 = ctrl_img_1 + self.ctrl_img_2 = ctrl_img_2 + self.ctrl_img_3 = ctrl_img_3 + + # prompt string will override any settings above + self._process_prompt_string() + + # handle dual text encoder prompts if nothing passed + if negative_prompt_2 is None: + self.negative_prompt_2 = negative_prompt + + if prompt_2 is None: + self.prompt_2 = self.prompt + + # parse prompt paths + if self.output_path is None and self.output_folder is None: + raise ValueError('output_path or output_folder must be specified') + elif self.output_path is not None: + self.output_folder = os.path.dirname(self.output_path) + self.output_ext = os.path.splitext(self.output_path)[1][1:] + self.output_filename_no_ext = os.path.splitext(os.path.basename(self.output_path))[0] + + else: + self.output_filename_no_ext = '[time]_[count]' + if len(self.output_tail) > 0: + self.output_filename_no_ext += '_' + self.output_tail + self.output_path = os.path.join(self.output_folder, self.output_filename_no_ext + '.' 
def get_image_path(self, count: int = 0, max_count=0):
    """
    Build the full output path for a generated file (method of GenerateImageConfig).

    The file name comes from ``_get_path_no_ext`` and is suffixed with
    ``output_ext``; a leading dot is added when ``output_ext`` lacks one.

    :param count: index of this output, forwarded to ``_get_path_no_ext``
    :param max_count: highest expected index, forwarded to ``_get_path_no_ext``
    :return: ``output_folder`` joined with the file name and extension
    """
    filename = self._get_path_no_ext(count, max_count)
    # An empty extension must not crash path building: indexing ext[0] raises
    # IndexError on '', so test the prefix with startswith() instead.
    ext = self.output_ext or ''
    if ext and not ext.startswith('.'):
        ext = '.' + ext
    # join with folder
    return os.path.join(self.output_folder, filename + ext)
def save_prompt_file(self, count: int = 0, max_count=0):
    """
    Write the prompt and its generation flags next to the generated file
    (method of GenerateImageConfig).

    The text is the prompt followed by ``--flag value`` pairs. The dual-encoder
    and negative prompts are added only when they are not None; width, height,
    seed, cfg, steps, network multiplier and guidance rescale are always added.

    :param count: index of this output, forwarded to ``get_prompt_path``
    :param max_count: highest expected index, forwarded to ``get_prompt_path``
    """
    parts = [self.prompt]
    optional_flags = (
        ('p2', self.prompt_2),
        ('n', self.negative_prompt),
        ('n2', self.negative_prompt_2),
    )
    parts.extend(f' --{flag} {value}' for flag, value in optional_flags if value is not None)
    required_flags = (
        ('w', self.width),
        ('h', self.height),
        ('seed', self.seed),
        ('cfg', self.guidance_scale),
        ('steps', self.num_inference_steps),
        ('m', self.network_multiplier),
        ('gr', self.guidance_rescale),
    )
    parts.extend(f' --{flag} {value}' for flag, value in required_flags)
    prompt = ''.join(parts)

    # utf-8 keeps non-ASCII prompts writable regardless of the platform default encoding
    with open(self.get_prompt_path(count, max_count), 'w', encoding='utf-8') as f:
        try:
            # write the assembled text; writing self.prompt here dropped every flag
            f.write(prompt)
        except Exception as e:
            # best effort: a failed prompt file must not abort image saving
            print(f"Error writing prompt file. Prompt contains non-unicode characters. {e}")
+ + # OURS and some QOL additions + # --m Specify the network multiplier for the generated image. + # --p2 Prompt for the second text encoder (SDXL only) + # --n2 Negative prompt for the second text encoder (SDXL only) + # --gr Specify the guidance rescale for the generated image (SDXL only) + + # --seed Specify the seed for the generated image same as --d + # --cfg Specify the CFG scale for the generated image same as --l + # --steps Specify the number of steps during generation same as --s + # --network_multiplier Specify the network multiplier for the generated image same as --m + + # process prompt string and update values if it has some + if self.prompt is not None and len(self.prompt) > 0: + # process prompt string + prompt = self.prompt + prompt = prompt.strip() + p_split = prompt.split('--') + self.prompt = p_split[0].strip() + + if len(p_split) > 1: + for split in p_split[1:]: + # allows multi char flags + flag = split.split(' ')[0].strip() + content = split[len(flag):].strip() + if flag == 'p2': + self.prompt_2 = content + elif flag == 'n': + self.negative_prompt = content + elif flag == 'n2': + self.negative_prompt_2 = content + elif flag == 'w': + self.width = int(content) + elif flag == 'h': + self.height = int(content) + elif flag == 'd': + self.seed = int(content) + elif flag == 'seed': + self.seed = int(content) + elif flag == 'l': + self.guidance_scale = float(content) + elif flag == 'cfg': + self.guidance_scale = float(content) + elif flag == 's': + self.num_inference_steps = int(content) + elif flag == 'steps': + self.num_inference_steps = int(content) + elif flag == 'm': + self.network_multiplier = float(content) + elif flag == 'network_multiplier': + self.network_multiplier = float(content) + elif flag == 'gr': + self.guidance_rescale = float(content) + elif flag == 'a': + self.adapter_conditioning_scale = float(content) + elif flag == 'ref': + self.refiner_start_at = float(content) + elif flag == 'ev': + # split by comma + self.extra_values = 
def log_image(self, image, count: int = 0, max_count=0):
    """
    Forward a generated image to the attached logger, if there is one
    (method of GenerateImageConfig).

    :param image: the generated image, passed to the logger unchanged
    :param count: index of this output, passed to the logger
    :param max_count: accepted for signature parity with the other helpers; unused
    """
    active_logger = self.logger
    if active_logger is not None:
        active_logger.log_image(image, count, self.prompt)
" + "Please set one of them to None.") + + # see if any datasets are caching text embeddings + is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs) + if is_caching_text_embeddings: + + # check if they are doing differential output preservation + if train_config.diff_output_preservation: + raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.") + + # make sure they are all cached + for dataset in dataset_configs: + if not dataset.cache_text_embeddings: + raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.") + + # qwen image edit cannot cache text embeddings + if model_config.arch in ['qwen_image_edit', 'boogu_image_edit']: + if train_config.unload_text_encoder: + raise ValueError(f"Cannot cache unload text encoder with {model_config.arch} model. Control images are encoded with text embeddings. You can cache the text embeddings though") + + if train_config.diff_output_preservation and train_config.blank_prompt_preservation: + raise ValueError("Cannot use both differential output preservation and blank prompt preservation at the same time. Please set one of them to False.") + + if train_config.batch_size > 1 and any(dataset_config.auto_frame_count for dataset_config in dataset_configs): + raise ValueError("Cannot use batch size greater than 1 with auto_frame_count. 
def flush(garbage_collect=True):
    """
    Release cached CUDA memory, optionally collecting Python garbage first.

    :param garbage_collect: when True, run ``gc.collect()`` before emptying the
        CUDA cache
    """
    # Collect first: memory freed by the collector only goes back to the caching
    # allocator, so emptying the cache afterwards is what actually releases it.
    if garbage_collect:
        gc.collect()
    torch.cuda.empty_cache()
def control_save_path(self, img_path, control_type):
    """
    Return where the generated control for ``img_path`` is stored
    (method of ControlGenerator).

    The file lives in a ``_controls`` folder beside the image and is named
    ``<image stem>.<control_type>.<ext>``.

    :param img_path: path of the source image
    :param control_type: control name embedded in the file name
    :return: path of the control file
    """
    parent_dir, base_name = os.path.split(img_path)
    stem, _ = os.path.splitext(base_name)
    # inpaint needs alpha and mask is a near-binary single channel; webp
    # compresses both far smaller than jpg. The rest stay jpg.
    if control_type in ('inpaint', 'mask'):
        suffix = 'webp'
    else:
        suffix = 'jpg'
    return os.path.join(parent_dir, '_controls', f"{stem}.{control_type}.{suffix}")
def postprocess(self, result, image, control_type):
    """
    CPU stage: turn an inference result into the final control image
    (method of ControlGenerator).

    :param result: output of ``run_inference`` for this image
    :param image: the source image; its size and pixels are used for
        'inpaint' and 'mask'
    :param control_type: which control is being produced
    :return: for 'mask', the mask resized to the image size; for 'inpaint', a
        copy of the image whose alpha channel is the inverted mask; for any
        other type, ``result`` unchanged
    """
    if control_type not in ('inpaint', 'mask'):
        # the fallback path already produced a finished PIL image
        return result

    to_pil = transforms.ToPILImage()
    mask = to_pil(result).resize(image.size)
    if control_type != 'inpaint':
        # keep the mask single-channel grayscale; the loader converts as
        # needed and this roughly thirds the file size vs RGB
        return mask

    # inpainting currently only supports the "erased" section to inpaint
    inverted = ImageOps.invert(mask)
    composited = image.copy()
    composited.putalpha(inverted)
    return composited
+ device = self.device + self.ensure_unloaded() + + if control_type == 'depth': + self.debug_print("Generating depth control") + if self.control_depth_model is None: + from transformers import pipeline + self.control_depth_model = pipeline( + task="depth-estimation", + model="depth-anything/Depth-Anything-V2-Large-hf", + device=device, + torch_dtype=torch.float16 + ) + img = image.copy() + in_size = img.size + output = self.control_depth_model(img) + out_tensor = output["predicted_depth"] # shape (1, H, W) 0 - 255 + out_tensor = out_tensor.clamp(0, 255) + out_tensor = out_tensor.squeeze(0).cpu().numpy() + img = Image.fromarray(out_tensor.astype('uint8')) + img = img.resize(in_size, Image.LANCZOS) + return img + elif control_type == 'pose': + self.debug_print("Generating pose control") + if self.control_pose_model is None: + try: + import onnxruntime + onnxruntime.set_default_logger_severity(3) + except ImportError: + raise ImportError( + "onnxruntime is not installed. Please install it with pip install onnxruntime or onnxruntime-gpu") + try: + from easy_dwpose import DWposeDetector + self.control_pose_model = DWposeDetector( + device=str(device)) + except ImportError: + raise ImportError( + "easy-dwpose is not installed. 
Please install it with pip install git+https://github.com/jaretburkett/easy_dwpose.git") + img = image.copy() + + detect_res = int(math.sqrt(img.size[0] * img.size[1])) + img = self.control_pose_model( + img, output_type="pil", include_hands=True, include_face=True, detect_resolution=detect_res) + img = img.convert('RGB') + return img + + elif control_type == 'line': + self.debug_print("Generating line control") + if self.control_line_model is None: + from controlnet_aux import TEEDdetector + self.control_line_model = TEEDdetector.from_pretrained( + "fal-ai/teed", filename="5_model.pth").to(device) + img = image.copy() + img = self.control_line_model(img, detect_resolution=1024) + # apply threshold + # img = img.filter(ImageFilter.GaussianBlur(radius=1)) + img = img.point(lambda p: p > 128 and 255) + img = img.convert('RGB') + return img + elif control_type in ['inpaint', 'mask']: + self.debug_print("Generating inpaint/mask control") + # delegate to the staged methods so this matches the threaded path + payload = self.preprocess(image, control_type) + result = self.run_inference(payload, control_type) + return self.postprocess(result, image, control_type) + elif control_type in ['sapiens2_mask']: + self.debug_print("Generating sapiens2_mask control") + if self.control_bg_remover is None: + from toolkit.models.sapiens2 import Sapiens2Matting + self.control_bg_remover = Sapiens2Matting.from_pretrained( + device=device, + dtype=torch.float16 + ) + img = image.copy() + img = self.control_bg_remover(img) + return img + else: + raise Exception(f"Error: unknown control type {control_type}") + + def _generate_control(self, img_path, control_type): + image = self.load_image(img_path) + out_image = self.run_control(image, control_type) + save_path = self.control_save_path(img_path, control_type) + self.save_control(out_image, save_path) + return save_path + + def cleanup(self): + if self.control_depth_model is not None: + self.control_depth_model = None + if 
if __name__ == "__main__":
    # CLI entry point: generate every control type for each image under a
    # directory tree and report the average generation time per control type.
    import sys
    import argparse
    import time
    import transformers
    transformers.logging.set_verbosity_error()

    control_times = {
        'depth': 0,
        'pose': 0,
        'line': 0,
        'inpaint': 0,
        'mask': 0
    }

    controls = control_times.keys()

    parser = argparse.ArgumentParser(description="Generate control images")
    parser.add_argument("img_dir", type=str, help="Path to image directory")
    parser.add_argument('--debug', action='store_true',
                        help="Enable debug mode")
    parser.add_argument('--regen', action='store_true',
                        help="Regenerate all controls")

    args = parser.parse_args()
    img_dir = args.img_dir
    # error paths exit non-zero so calling scripts can detect the failure
    if not os.path.exists(img_dir):
        print(f"Error: {img_dir} does not exist")
        sys.exit(1)
    if not os.path.isdir(img_dir):
        print(f"Error: {img_dir} is not a directory")
        sys.exit(1)

    # find images, skipping generated control folders and hidden files
    img_list = []
    for root, dirs, files in os.walk(img_dir):
        if "_controls" in root:
            continue
        for file in files:
            if file.startswith('.'):
                continue
            if file.lower().endswith(tuple(img_ext_list)):
                img_list.append(os.path.join(root, file))
    if len(img_list) == 0:
        print(f"Error: no images found in {img_dir}")
        sys.exit(1)

    # the first images are warm-up and are left out of the averages
    warmup_images = 2
    timed_images = 0
    for img_index, img_path in enumerate(tqdm(img_list)):
        is_warmup = img_index < warmup_images
        for control in controls:
            start = time.time()
            # NOTE(review): a fresh generator per call starts with no models
            # loaded, so each timing includes model loading — confirm intended.
            control_gen = ControlGenerator(torch.device('cuda'))
            control_gen.debug = args.debug
            control_gen.regen = args.regen
            control_gen.get_control_path(img_path, control)
            elapsed = time.time() - start
            if not is_warmup:
                control_times[control] += elapsed
        if not is_warmup:
            timed_images += 1

    # determine avg time; with only warm-up images there is nothing to average
    # (dividing by `idx - 2` used to fail or go negative in that case)
    if timed_images == 0:
        print(
            f"Not enough images to report average times: the first {warmup_images} images are warm-up only")
    else:
        for control in controls:
            control_times[control] /= timed_images
            print(
                f"Avg time for {control} control: {control_times[control]:.2f} seconds")

    print("Done")
# Decide at import time whether to switch PyTorch to the cudaMallocAsync
# allocator backend. This must run before torch is first imported, so the torch
# version is read by executing torch's version.py instead of importing torch.
cuda_malloc = False

if not cuda_malloc:
    try:
        version = ""
        torch_spec = importlib.util.find_spec("torch")
        search_locations = torch_spec.submodule_search_locations if torch_spec is not None else None
        for folder in search_locations or []:
            ver_file = os.path.join(folder, "version.py")
            if os.path.isfile(ver_file):
                spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                version = module.__version__
        # compare the whole major component; version[0] is only its first digit,
        # which would read "10.x" as major version 1
        major = version.split('.')[0]
        if major.isdigit() and int(major) >= 2:  # enable by default for torch version 2.0 and up
            cuda_malloc = cuda_malloc_supported()
    except Exception:
        # best effort: any failure while probing torch leaves the async allocator
        # off; KeyboardInterrupt/SystemExit are no longer swallowed
        pass

if cuda_malloc:
    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
    if not env_var:
        # unset or empty: avoid producing a leading comma
        env_var = "backend:cudaMallocAsync"
    else:
        env_var += ",backend:cudaMallocAsync"

    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
    print("CUDA Malloc Async Enabled")
import I2VAdapter +from toolkit.models.subpixel_adapter import SubpixelAdapter +from toolkit.models.ilora import InstantLoRAModule +from toolkit.models.single_value_adapter import SingleValueAdapter +from toolkit.models.te_adapter import TEAdapter +from toolkit.models.te_aug_adapter import TEAugAdapter +from toolkit.models.vd_adapter import VisionDirectAdapter +from toolkit.models.redux import ReduxImageEncoder +from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder +from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model +from toolkit.train_tools import get_torch_dtype +from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible +import random +from toolkit.util.mask import generate_random_mask +from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict +from collections import OrderedDict +from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig +from toolkit.prompt_utils import PromptEmbeds +import weakref + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + +from transformers import ( + ConvNextForImageClassification, + ConvNextImageProcessor, + UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer, BitsAndBytesConfig +) +from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel + +from toolkit.models.llm_adapter import LLMAdapter + +import torch.nn.functional as F + + +class CustomAdapter(torch.nn.Module): + def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig', train_config: 'TrainConfig'): + super().__init__() + self.config = adapter_config + self.sd_ref: weakref.ref = weakref.ref(sd) + self.train_config = train_config + self.device = self.sd_ref().unet.device + self.image_processor: CLIPImageProcessor = None + self.input_size = 224 + self.adapter_type: AdapterTypes = self.config.type + self.current_scale = 
1.0 + self.is_active = True + self.flag_word = "fla9wor0" + self.is_unconditional_run = False + self.is_sampling = False + + self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None + + self.fuse_module: FuseModule = None + + self.lora: None = None + + self.position_ids: Optional[List[int]] = None + + self.num_control_images = self.config.num_control_images + self.token_mask: Optional[torch.Tensor] = None + + # setup clip + self.setup_clip() + # add for dataloader + self.clip_image_processor = self.image_processor + + self.clip_fusion_module: CLIPFusionModule = None + self.ilora_module: InstantLoRAModule = None + + self.te: Union[T5EncoderModel, CLIPTextModel] = None + self.tokenizer: CLIPTokenizer = None + self.te_adapter: TEAdapter = None + self.te_augmenter: TEAugAdapter = None + self.vd_adapter: VisionDirectAdapter = None + self.single_value_adapter: SingleValueAdapter = None + self.redux_adapter: ReduxImageEncoder = None + self.control_lora: ControlLoraAdapter = None + self.mean_flow_adapter: MeanFlowAdapter = None + self.subpixel_adapter: SubpixelAdapter = None + self.i2v_adapter: I2VAdapter = None + + self.conditional_embeds: Optional[torch.Tensor] = None + self.unconditional_embeds: Optional[torch.Tensor] = None + + self.cached_control_image_0_1: Optional[torch.Tensor] = None + + self.setup_adapter() + + if self.adapter_type == 'photo_maker': + # try to load from our name_or_path + if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'): + self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False) + # add the trigger word to the tokenizer + if isinstance(self.sd_ref().tokenizer, list): + for tokenizer in self.sd_ref().tokenizer: + tokenizer.add_tokens([self.flag_word], special_tokens=True) + else: + self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True) + elif self.config.name_or_path is not None: + loaded_state_dict = 
@property
def do_direct_save(self):
    """Whether this adapter's weights are saved directly.

    Some adapters save their weights directly, others (like ip adapters)
    split the state dict. True when only the image encoder is trained, or
    when the configured type is one of the direct-save types listed below.
    """
    direct_save_types = ('control_lora', 'subpixel', 'i2v', 'redux', 'mean_flow')
    cfg = self.config
    if cfg.train_only_image_encoder:
        return True
    return cfg.type in direct_save_types
te_kwargs['load_in_4bit'] = True + # te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + self.te = T5EncoderModel.from_pretrained( + self.config.text_encoder_path, + torch_dtype=torch_dtype, + **te_kwargs + ) + + # self.te.to = lambda *args, **kwargs: None + self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path) + elif self.config.text_encoder_arch == 'pile-t5': + te_kwargs = {} + # te_kwargs['load_in_4bit'] = True + # te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + self.te = UMT5EncoderModel.from_pretrained( + self.config.text_encoder_path, + torch_dtype=torch_dtype, + **te_kwargs + ) + + # self.te.to = lambda *args, **kwargs: None + self.tokenizer = LlamaTokenizerFast.from_pretrained(self.config.text_encoder_path) + if self.tokenizer.pad_token is None: + self.tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + elif self.config.text_encoder_arch == 'clip': + self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device, + dtype=torch_dtype) + self.tokenizer = CLIPTokenizer.from_pretrained(self.config.text_encoder_path) + else: + raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}") + + self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer) + elif self.adapter_type == 'llm_adapter': + kwargs = {} + if self.config.quantize_llm: + bnb_kwargs = { + 'load_in_4bit': True, + 'bnb_4bit_quant_type': "nf4", + 'bnb_4bit_compute_dtype': torch.bfloat16 + } + quantization_config = BitsAndBytesConfig(**bnb_kwargs) + kwargs['quantization_config'] = quantization_config + kwargs['torch_dtype'] = torch_dtype + self.te = AutoModel.from_pretrained( + self.config.text_encoder_path, + **kwargs + ) + else: + self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to( + self.sd_ref().unet.device, + dtype=torch_dtype, + ) + self.te.to = lambda *args, **kwargs: None + self.te.eval() 
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path) + self.llm_adapter = LLMAdapter( + adapter=self, + sd=self.sd_ref(), + llm=self.te, + tokenizer=self.tokenizer, + num_cloned_blocks=self.config.num_cloned_blocks, + ) + self.llm_adapter.to(self.device, torch_dtype) + elif self.adapter_type == 'te_augmenter': + self.te_augmenter = TEAugAdapter(self, self.sd_ref()) + elif self.adapter_type == 'vision_direct': + self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder) + elif self.adapter_type == 'single_value': + self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens) + elif self.adapter_type == 'redux': + vision_hidden_size = self.vision_encoder.config.hidden_size + self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype) + elif self.adapter_type == 'mean_flow': + self.mean_flow_adapter = MeanFlowAdapter( + self, + sd=self.sd_ref(), + config=self.config, + train_config=self.train_config + ) + elif self.adapter_type == 'control_lora': + self.control_lora = ControlLoraAdapter( + self, + sd=self.sd_ref(), + config=self.config, + train_config=self.train_config + ) + elif self.adapter_type == 'i2v': + self.i2v_adapter = I2VAdapter( + self, + sd=self.sd_ref(), + config=self.config, + train_config=self.train_config, + image_processor=self.image_processor, + vision_encoder=self.vision_encoder, + ) + elif self.adapter_type == 'subpixel': + self.subpixel_adapter = SubpixelAdapter( + self, + sd=self.sd_ref(), + config=self.config, + train_config=self.train_config + ) + else: + raise ValueError(f"unknown adapter type: {self.adapter_type}") + + def forward(self, *args, **kwargs): + # dont think this is used + # if self.adapter_type == 'photo_maker': + # id_pixel_values = args[0] + # prompt_embeds: PromptEmbeds = args[1] + # class_tokens_mask = args[2] + # + # grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled() + # + 
# with torch.set_grad_enabled(grads_on_image_encoder): + # id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False) + # + # if not grads_on_image_encoder: + # id_embeds = id_embeds.detach() + # + # prompt_embeds = prompt_embeds.detach() + # + # updated_prompt_embeds = self.fuse_module( + # prompt_embeds, id_embeds, class_tokens_mask + # ) + # + # return updated_prompt_embeds + # else: + raise NotImplementedError + + def edit_batch_raw(self, batch: DataLoaderBatchDTO): + # happens on a raw batch before latents are created + return batch + + def edit_batch_processed(self, batch: DataLoaderBatchDTO): + # happens after the latents are processed + if self.adapter_type == "i2v": + return self.i2v_adapter.edit_batch_processed(batch) + return batch + + def setup_clip(self): + adapter_config = self.config + sd = self.sd_ref() + if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel", "mean_flow"]: + return + if self.config.type == 'photo_maker': + try: + self.image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path) + except EnvironmentError: + self.image_processor = CLIPImageProcessor() + if self.config.image_encoder_path is None: + self.vision_encoder = PhotoMakerCLIPEncoder() + else: + self.vision_encoder = PhotoMakerCLIPEncoder.from_pretrained(self.config.image_encoder_path) + elif self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+': + try: + self.image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = CLIPImageProcessor() + self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'siglip': + from transformers import SiglipImageProcessor, SiglipVisionModel + try: + self.image_processor = 
SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = SiglipImageProcessor() + self.vision_encoder = SiglipVisionModel.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'siglip2': + from transformers import SiglipImageProcessor, SiglipVisionModel + try: + self.image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = SiglipImageProcessor() + self.vision_encoder = SiglipVisionModel.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'pixtral': + self.image_processor = PixtralVisionImagePreprocessorCompatible( + max_image_size=self.config.pixtral_max_image_size, + ) + self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained( + adapter_config.image_encoder_path, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'safe': + try: + self.image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = SAFEImageProcessor() + self.vision_encoder = SAFEVisionModel( + in_channels=3, + num_tokens=self.config.safe_tokens, + num_vectors=sd.unet_unwrapped.config['cross_attention_dim'], + reducer_channels=self.config.safe_reducer_channels, + channels=self.config.safe_channels, + downscale_factor=8 + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'convnext': + try: + self.image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") 
+ self.image_processor = ConvNextImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.vision_encoder = ConvNextForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + else: + raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}") + + self.input_size = self.vision_encoder.config.image_size + + if self.config.quad_image: # 4x4 image + # self.clip_image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + preprocessor_input_size = self.vision_encoder.config.image_size * 2 + + # update the preprocessor so images come in at the right size + if 'height' in self.image_processor.size: + self.image_processor.size['height'] = preprocessor_input_size + self.image_processor.size['width'] = preprocessor_input_size + elif hasattr(self.image_processor, 'crop_size'): + self.image_processor.size['shortest_edge'] = preprocessor_input_size + self.image_processor.crop_size['height'] = preprocessor_input_size + self.image_processor.crop_size['width'] = preprocessor_input_size + + if self.config.image_encoder_arch == 'clip+': + # self.image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + preprocessor_input_size = self.vision_encoder.config.image_size * 4 + + # update the preprocessor so images come in at the right size + self.image_processor.size['shortest_edge'] = preprocessor_input_size + self.image_processor.crop_size['height'] = preprocessor_input_size + self.image_processor.crop_size['width'] = preprocessor_input_size + + self.preprocessor = CLIPImagePreProcessor( + input_size=preprocessor_input_size, + clip_input_size=self.vision_encoder.config.image_size, + ) + if 'height' in self.image_processor.size: + self.input_size = self.image_processor.size['height'] + else: + 
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
    """Distribute a saved adapter checkpoint across the sub-modules it targets.

    Each recognised top-level key of ``state_dict`` is forwarded to the
    matching sub-module. For the control_lora / mean_flow / i2v / subpixel
    adapter types (and when a ``redux_up`` key is present) the checkpoint is
    treated as ``{prefix: {name: value}}`` and flattened to ``prefix.name``
    keys first.

    Args:
        state_dict: The loaded checkpoint mapping.
        strict: Accepted for signature compatibility but ignored; every load
            is non-strict, except the redux adapter which is loaded strictly.
    """
    strict = False  # the caller's value is discarded on purpose

    def merge_nested():
        # state dict is separated, so recombine it into flat "prefix.name" keys
        flat = {}
        for prefix, group in state_dict.items():
            for name, value in group.items():
                flat[prefix + '.' + name] = value
        return flat

    if self.config.train_only_image_encoder and not (
            'vd_adapter' in state_dict or 'dvadapter' in state_dict):
        # we are loading pure clip weights.
        self.vision_encoder.load_state_dict(state_dict, strict=strict)

    # todo add LoRA: a 'lora_weights' entry is currently not loaded at all

    if 'clip_fusion' in state_dict:
        self.clip_fusion_module.load_state_dict(state_dict['clip_fusion'], strict=strict)

    if 'id_encoder' in state_dict and self.adapter_type in ('photo_maker', 'clip_fusion'):
        id_weights = state_dict['id_encoder']
        self.vision_encoder.load_state_dict(id_weights, strict=strict)
        # check to see if the fuse weights are there
        fuse_weights = {
            key.replace('fuse_module.', ''): tensor
            for key, tensor in id_weights.items()
            if key.startswith('fuse_module')
        }
        if fuse_weights:
            try:
                self.fuse_module.load_state_dict(fuse_weights, strict=strict)
            except Exception as err:
                print(err)
                # force load it: crop every saved tensor to the current shape
                print(f"force loading fuse module as it did not match")
                target = self.fuse_module.state_dict()
                for key, tensor in fuse_weights.items():
                    rank = len(tensor.shape)
                    if rank not in (1, 2, 3, 4):
                        raise ValueError(f"unknown shape: {tensor.shape}")
                    crop = tuple(slice(None, target[key].shape[axis]) for axis in range(rank))
                    target[key] = tensor[crop]
                self.fuse_module.load_state_dict(target, strict=strict)

    # (checkpoint key, attribute holding the module that receives it)
    direct_loads = (
        ('te_adapter', 'te_adapter'),
        ('llm_adapter', 'llm_adapter'),
        ('te_augmenter', 'te_augmenter'),
        ('vd_adapter', 'vd_adapter'),
        ('dvadapter', 'vd_adapter'),
        ('sv_adapter', 'single_value_adapter'),
        ('vision_encoder', 'vision_encoder'),
        ('fuse_module', 'fuse_module'),
    )
    for key, attr in direct_loads:
        if key in state_dict:
            getattr(self, attr).load_state_dict(state_dict[key], strict=strict)

    if 'ilora' in state_dict:
        try:
            self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
        except Exception as err:
            print(err)

    if 'redux_up' in state_dict:
        merged = merge_nested()
        self.redux_adapter.load_state_dict(merged, strict=True)

    # adapter types whose weights go through load_weights on a flat dict
    weight_loaders = (
        ('control_lora', 'control_lora'),
        ('mean_flow', 'mean_flow_adapter'),
        ('i2v', 'i2v_adapter'),
        ('subpixel', 'subpixel_adapter'),
    )
    for type_name, attr in weight_loaders:
        if self.adapter_type == type_name:
            merged = merge_nested()
            getattr(self, attr).load_weights(merged, strict=strict)
def condition_noisy_latents(self, latents: torch.Tensor, batch: 'DataLoaderBatchDTO'):
    """Append adapter conditioning channels to the noisy latents.

    For the ``i2v`` adapter the work is delegated to
    ``self.i2v_adapter.condition_noisy_latents``. For ``control_lora`` the
    optional inpainting latent (masked ``batch.latents`` plus a one-channel
    mask) and the encoded control images are concatenated onto ``latents``
    along dim 1, inpainting first. Every other adapter type gets ``latents``
    back untouched. The whole method runs under ``torch.no_grad()``.

    Args:
        latents: Noisy latents; indexed here as a 4-D tensor where dim 1 is
            the concat axis and dims 2/3 are used as height/width.
        batch: Batch providing ``inpaint_tensor``, ``control_tensor``,
            ``latents`` and ``tensor``.

    Returns:
        The conditioned latents (detached on the ``control_lora`` path).
    """
    with torch.no_grad():
        # todo add i2v start frame conditioning here

        if self.adapter_type in ['i2v']:
            return self.i2v_adapter.condition_noisy_latents(latents, batch)
        elif self.adapter_type in ['control_lora']:
            # inpainting input is 0-1 (bs, 4, h, w) on batch.inpaint_tensor
            # 4th channel is the mask with 1 being keep area and 0 being area to inpaint.
            sd: 'StableDiffusion' = self.sd_ref()
            inpainting_latent = None
            if self.config.has_inpainting_input:
                do_dropout = random.random() < self.config.control_image_dropout
                # do random mask if we dont have one
                inpaint_tensor = batch.inpaint_tensor
                if inpaint_tensor is None and not do_dropout:
                    # generate a random one since we dont have one
                    # this will make random blobs, invert the blobs for now as we normally inpaint the alpha
                    inpaint_tensor = 1 - generate_random_mask(
                        batch_size=latents.shape[0],
                        height=latents.shape[2],
                        width=latents.shape[3],
                        device=latents.device,
                    ).to(latents.device, latents.dtype)
                if inpaint_tensor is not None and not do_dropout:
                    if inpaint_tensor.shape[1] == 4:
                        # get just the mask
                        inpainting_tensor_mask = inpaint_tensor[:, 3:4, :, :].to(latents.device, dtype=latents.dtype)
                    elif inpaint_tensor.shape[1] == 3:
                        # rgb mask. Just get one channel
                        inpainting_tensor_mask = inpaint_tensor[:, 0:1, :, :].to(latents.device, dtype=latents.dtype)
                    else:
                        inpainting_tensor_mask = inpaint_tensor

                    # use our batch latents so we can avoid encoding again
                    inpainting_latent = batch.latents

                    # resize the mask to match the new encoded size
                    inpainting_tensor_mask = F.interpolate(
                        inpainting_tensor_mask,
                        size=(inpainting_latent.shape[2], inpainting_latent.shape[3]),
                        mode='bilinear',
                    )
                    inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype)

                    do_mask_invert = False
                    if self.config.invert_inpaint_mask_chance > 0.0:
                        do_mask_invert = random.random() < self.config.invert_inpaint_mask_chance
                    if do_mask_invert:
                        # invert the mask
                        inpainting_tensor_mask = 1 - inpainting_tensor_mask

                    # mask out the inpainting area, it is currently 0 for inpaint area, and 1 for keep area
                    # we are zeroing out the latents in the inpaint area, not the pixel space.
                    inpainting_latent = inpainting_latent * inpainting_tensor_mask

                    # mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it.
                    inpainting_tensor_mask = 1 - inpainting_tensor_mask
                    # leave the mask as 0-1 and concat on channel of latents
                    inpainting_latent = torch.cat((inpainting_latent, inpainting_tensor_mask), dim=1)
                else:
                    # we have inpainting but didnt get a control, or we are doing a dropout
                    # the input needs to be all zeros for the latents and all 1s for the mask
                    inpainting_latent = torch.zeros_like(latents)
                    # add ones for the mask since we are technically inpainting everything
                    inpainting_latent = torch.cat(
                        (inpainting_latent, torch.ones_like(inpainting_latent[:, :1, :, :])), dim=1
                    )

                if self.config.num_control_images == 1:
                    # this is our only control
                    control_latent = inpainting_latent.to(latents.device, latents.dtype)
                    latents = torch.cat((latents, control_latent), dim=1)
                    return latents.detach()

            # Bind the batch's control tensor BEFORE testing it. The name is
            # assigned further down, which makes it a local; reading it here
            # without this binding raises UnboundLocalError.
            control_tensor = batch.control_tensor
            if control_tensor is None:
                # concat zeros onto the latents
                ctrl = torch.zeros(
                    latents.shape[0],  # bs
                    latents.shape[1] * self.num_control_images,  # ch
                    latents.shape[2],
                    latents.shape[3],
                    device=latents.device,
                    dtype=latents.dtype
                )
                if inpainting_latent is not None:
                    # inpainting always comes first
                    ctrl = torch.cat((inpainting_latent, ctrl), dim=1)
                latents = torch.cat((latents, ctrl), dim=1)
                return latents.detach()

            # if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w]
            # if we have 1, it comes in like [bs, ch, h, w]
            # stack our control tensors to be [bs, ch * num_control_images, h, w]
            control_tensor = control_tensor.to(latents.device, dtype=latents.dtype)

            control_tensor_list = []
            if len(control_tensor.shape) == 4:
                control_tensor_list.append(control_tensor)
            else:
                # reshape
                control_tensor = control_tensor.view(
                    control_tensor.shape[0],
                    control_tensor.shape[1] * control_tensor.shape[2],
                    control_tensor.shape[3],
                    control_tensor.shape[4]
                )
                control_tensor_list = control_tensor.chunk(self.num_control_images, dim=1)

            control_latent_list = []
            for control_tensor in control_tensor_list:
                # one dropout draw per control image
                do_dropout = random.random() < self.config.control_image_dropout
                if do_dropout:
                    # dropped control images are replaced by zeros shaped like batch.latents
                    control_latent_list.append(torch.zeros_like(batch.latents))
                else:
                    # it is 0-1 need to convert to -1 to 1
                    control_tensor = control_tensor * 2 - 1

                    control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)

                    # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
                    if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
                        control_tensor = F.interpolate(
                            control_tensor,
                            size=(batch.tensor.shape[2], batch.tensor.shape[3]),
                            mode='bicubic',
                        )

                    # encode it
                    control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
                    control_latent_list.append(control_latent)
            # stack them on the channel dimension
            control_latent = torch.cat(control_latent_list, dim=1)
            if inpainting_latent is not None:
                # inpainting always comes first
                control_latent = torch.cat((inpainting_latent, control_latent), dim=1)
            # concat it onto the latents
            latents = torch.cat((latents, control_latent), dim=1)
            return latents.detach()
        return latents
self.conditional_embeds = self.llm_adapter.encode_text(prompt).detach() + return prompt + elif self.adapter_type == 'photo_maker': + if is_unconditional: + return prompt + else: + + with torch.no_grad(): + was_list = isinstance(prompt, list) + if not was_list: + prompt_list = [prompt] + else: + prompt_list = prompt + + new_prompt_list = [] + token_mask_list = [] + + for prompt in prompt_list: + + our_class = None + # find a class in the prompt + prompt_parts = prompt.split(' ') + prompt_parts = [p.strip().lower() for p in prompt_parts if len(p) > 0] + + new_prompt_parts = [] + tokened_prompt_parts = [] + for idx, prompt_part in enumerate(prompt_parts): + new_prompt_parts.append(prompt_part) + tokened_prompt_parts.append(prompt_part) + if prompt_part in self.config.class_names: + our_class = prompt_part + # add the flag word + tokened_prompt_parts.append(self.flag_word) + + if self.num_control_images > 1: + # add the rest + for _ in range(self.num_control_images - 1): + new_prompt_parts.extend(prompt_parts[idx + 1:]) + + # add the rest + tokened_prompt_parts.extend(prompt_parts[idx + 1:]) + new_prompt_parts.extend(prompt_parts[idx + 1:]) + + break + + prompt = " ".join(new_prompt_parts) + tokened_prompt = " ".join(tokened_prompt_parts) + + if our_class is None: + # add the first one to the front of the prompt + tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt + our_class = self.config.class_names[0] + prompt = " ".join( + [self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt + + # add the prompt to the list + new_prompt_list.append(prompt) + + # tokenize them with just the first tokenizer + tokenizer = self.sd_ref().tokenizer + if isinstance(tokenizer, list): + tokenizer = tokenizer[0] + + flag_token = tokenizer.convert_tokens_to_ids(self.flag_word) + + tokenized_prompt = tokenizer.encode(prompt) + tokenized_tokened_prompt = tokenizer.encode(tokened_prompt) + + flag_idx = 
def condition_encoded_embeds(
        self,
        tensors_0_1: torch.Tensor,
        prompt_embeds: PromptEmbeds,
        is_training=False,
        has_been_preprocessed=False,
        is_unconditional=False,
        quad_count=4,
        is_generating_samples=False,
) -> PromptEmbeds:
    """Apply this adapter's conditioning to already-encoded prompt embeds.

    What happens depends on ``self.adapter_type``:

    * ``'text_encoder'``: ``prompt_embeds`` is ignored; a clone of the cached
      ``self.unconditional_embeds`` / ``self.conditional_embeds`` is returned.
    * ``'llm_adapter'``: ``text_embeds`` and ``attention_mask`` of
      ``prompt_embeds`` are overwritten in place with clones taken from the
      cached embeds, and ``prompt_embeds`` is returned.
    * ``'ilora'``: ``prompt_embeds`` is returned untouched.
    * ``'photo_maker'`` / ``'clip_fusion'`` / ``'redux'``: the control image is
      run through ``self.vision_encoder`` and the result is merged into
      ``prompt_embeds.text_embeds`` (in place). Unconditional embeds are
      returned as a clone without any image conditioning.
    * anything else: ``prompt_embeds`` is returned untouched.

    Args:
        tensors_0_1: Control image(s). When ``has_been_preprocessed`` is False
            they are validated against roughly the 0-1 range and passed through
            ``self.image_processor``; otherwise they are used as-is.
        prompt_embeds: Encoded prompt to condition.
        is_training: Enables gradients for the vision encoder / fusion modules.
        has_been_preprocessed: True when ``tensors_0_1`` is already in the
            vision encoder's input format.
        is_unconditional: True when ``prompt_embeds`` is the negative prompt.
        quad_count: Number of quadrants to use when ``config.quad_image`` is set.
        is_generating_samples: Not read anywhere in this method.

    Returns:
        The conditioned embeds (see the per-type notes above for whether this is
        a new object or ``prompt_embeds`` mutated in place).

    Raises:
        ValueError: If un-preprocessed image values fall outside [-0.3, 1.3].
    """
    if self.adapter_type == 'text_encoder':
        # replace the prompt embed with ours
        if is_unconditional:
            return self.unconditional_embeds.clone()
        return self.conditional_embeds.clone()
    if self.adapter_type == 'llm_adapter':
        # replace the prompt embed with ours
        if is_unconditional:
            prompt_embeds.text_embeds = self.unconditional_embeds.text_embeds.clone()
            prompt_embeds.attention_mask = self.unconditional_embeds.attention_mask.clone()
            return prompt_embeds
        prompt_embeds.text_embeds = self.conditional_embeds.text_embeds.clone()
        prompt_embeds.attention_mask = self.conditional_embeds.attention_mask.clone()
        return prompt_embeds

    if self.adapter_type == 'ilora':
        return prompt_embeds

    if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'redux':
        if is_unconditional:
            # we don't condition the negative embeds for these adapter types
            return prompt_embeds.clone()
        # image preparation never needs gradients
        with torch.no_grad():
            # on training the clip image is created in the dataloader
            if not has_been_preprocessed:
                # tensors should be 0-1
                if tensors_0_1.ndim == 3:
                    # a single un-batched image: add the batch dimension
                    tensors_0_1 = tensors_0_1.unsqueeze(0)
                # training tensors are 0 - 1
                tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
                # if images are out of this range throw error
                # (the bounds are deliberately loose: -0.3 .. 1.3)
                if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                    raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                        tensors_0_1.min(), tensors_0_1.max()
                    ))
                # do_rescale=False because the input is already 0-1
                clip_image = self.image_processor(
                    images=tensors_0_1,
                    return_tensors="pt",
                    do_resize=True,
                    do_rescale=False,
                    do_convert_rgb=True
                ).pixel_values
            else:
                clip_image = tensors_0_1
        clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

        if self.config.quad_image:
            # split the 2x2 grid into quadrants and stack them on the batch dim.
            # dim=2 is halved first, then dim=3 (presumably height then width
            # for an NCHW tensor - TODO confirm), which yields the order
            # ci1, ci2, ci3, ci4 used below.
            ci1, ci2 = clip_image.chunk(2, dim=2)
            ci1, ci3 = ci1.chunk(2, dim=3)
            ci2, ci4 = ci2.chunk(2, dim=3)
            to_cat = []
            # keep only the first ``quad_count`` quadrants
            for i, ci in enumerate([ci1, ci2, ci3, ci4]):
                if i < quad_count:
                    to_cat.append(ci)
                else:
                    break

            clip_image = torch.cat(to_cat, dim=0).detach()

        if self.adapter_type == 'photo_maker':
            # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
            clip_image = clip_image.unsqueeze(1)
            with torch.set_grad_enabled(is_training):
                if is_training and self.config.train_image_encoder:
                    # the image encoder itself is being trained: keep the graph
                    self.vision_encoder.train()
                    clip_image = clip_image.requires_grad_(True)
                    id_embeds = self.vision_encoder(
                        clip_image,
                        do_projection2=isinstance(self.sd_ref().text_encoder, list),
                    )
                else:
                    # frozen encoder: run in eval mode without gradients
                    with torch.no_grad():
                        self.vision_encoder.eval()
                        id_embeds = self.vision_encoder(
                            clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
                        ).detach()

                # merge the identity embeds into the text embeds at the
                # positions flagged by self.token_mask
                prompt_embeds.text_embeds = self.fuse_module(
                    prompt_embeds.text_embeds,
                    id_embeds,
                    self.token_mask
                )
            return prompt_embeds
        elif self.adapter_type == 'clip_fusion':
            with torch.set_grad_enabled(is_training):
                if is_training and self.config.train_image_encoder:
                    self.vision_encoder.train()
                    clip_image = clip_image.requires_grad_(True)
                    id_embeds = self.vision_encoder(
                        clip_image,
                        output_hidden_states=True,
                    )
                else:
                    with torch.no_grad():
                        self.vision_encoder.eval()
                        id_embeds = self.vision_encoder(
                            clip_image, output_hidden_states=True
                        )

                img_embeds = id_embeds['last_hidden_state']

                if self.config.quad_image:
                    # average the encoder outputs of the quadrants that were
                    # stacked on the batch dim above
                    chunks = img_embeds.chunk(quad_count, dim=0)
                    chunk_sum = torch.zeros_like(chunks[0])
                    for chunk in chunks:
                        chunk_sum = chunk_sum + chunk
                    # get the mean of them

                    img_embeds = chunk_sum / quad_count

                if not is_training or not self.config.train_image_encoder:
                    # nothing upstream of here is trained: cut the graph
                    img_embeds = img_embeds.detach()

                prompt_embeds.text_embeds = self.clip_fusion_module(
                    prompt_embeds.text_embeds,
                    img_embeds
                )
            return prompt_embeds

        elif self.adapter_type == 'redux':
            with torch.set_grad_enabled(is_training):
                if is_training and self.config.train_image_encoder:
                    self.vision_encoder.train()
                    clip_image = clip_image.requires_grad_(True)
                    id_embeds = self.vision_encoder(
                        clip_image,
                        output_hidden_states=True,
                    )
                else:
                    with torch.no_grad():
                        self.vision_encoder.eval()
                        id_embeds = self.vision_encoder(
                            clip_image, output_hidden_states=True
                        )

                img_embeds = id_embeds['last_hidden_state']

                if self.config.quad_image:
                    # average the encoder outputs of the stacked quadrants
                    chunks = img_embeds.chunk(quad_count, dim=0)
                    chunk_sum = torch.zeros_like(chunks[0])
                    for chunk in chunks:
                        chunk_sum = chunk_sum + chunk
                    # get the mean of them

                    img_embeds = chunk_sum / quad_count

                if not is_training or not self.config.train_image_encoder:
                    img_embeds = img_embeds.detach()

                img_embeds = self.redux_adapter(img_embeds.to(self.device, get_torch_dtype(self.sd_ref().dtype)))

                # redux appends the image tokens after the text tokens
                # (concatenation on the second-to-last dim)
                prompt_embeds.text_embeds = torch.cat((prompt_embeds.text_embeds, img_embeds), dim=-2)
            return prompt_embeds
    else:
        # adapter types that do not touch the encoded prompt
        return prompt_embeds
def trigger_pre_te(
        self,
        tensors_0_1: Optional[torch.Tensor] = None,
        tensors_preprocessed: Optional[torch.Tensor] = None,  # preprocessed by the dataloader
        is_training=False,
        has_been_preprocessed=False,
        batch_tensor: Optional[torch.Tensor] = None,
        quad_count=4,
        batch_size=1,
) -> PromptEmbeds:
    """Encode the control image before the text encoder runs and cache the result.

    Side effects only; no value is returned on any path (the return annotation
    is kept for interface compatibility):

    * ``self.cached_control_image_0_1`` is set to ``tensors_0_1``, or to
      ``batch_tensor / 2 + 0.5`` (first frame only for 5-D video batches), or
      to ``None``.
    * For ``'ilora'`` the image embeds are pushed into ``self.ilora_module``.
    * For ``'vision_direct'`` / ``'te_augmenter'`` / ``'i2v'`` the embeds are
      stored in ``self.unconditional_embeds`` / ``self.conditional_embeds``.
      When the model is flux the unconditional half is skipped and stored as
      ``None``.

    Args:
        tensors_0_1: Raw control image(s), expected to be roughly in 0-1.
        tensors_preprocessed: Control image(s) already prepared for the vision
            encoder; used instead of ``tensors_0_1`` when
            ``has_been_preprocessed`` is True.
        is_training: Enables gradients and control-image dropout.
        has_been_preprocessed: Whether the image skips ``self.image_processor``.
        batch_tensor: Training batch (arithmetic below treats it as -1..1),
            used only to fill the cached control image when ``tensors_0_1`` is
            missing.
        quad_count: Number of quadrants to use when ``config.quad_image`` is set.
        batch_size: Batch size for the noise image generated when no control
            image is supplied.

    Raises:
        ValueError: If un-preprocessed image values fall outside [-0.3, 1.3],
            if ``config.clip_layer`` is unknown for ilora, if the vision
            encoder output exposes no usable embedding, or if the embeds cannot
            be split into unconditional / conditional halves.
    """
    if tensors_0_1 is not None:
        # actual 0 - 1 image
        self.cached_control_image_0_1 = tensors_0_1
    else:
        # image has been processed through the dataloader and is prepped for vision encoder
        self.cached_control_image_0_1 = None
    if batch_tensor is not None and self.cached_control_image_0_1 is None:
        # convert it to 0 - 1
        to_cache = batch_tensor / 2 + 0.5
        # videos come in (bs, num_frames, channels, height, width)
        # images come in (bs, channels, height, width)
        # if it is a video, just grab the first frame
        if len(to_cache.shape) == 5:
            to_cache = to_cache[:, 0:1, :, :, :]
            to_cache = to_cache.squeeze(1)
        self.cached_control_image_0_1 = to_cache

    if tensors_preprocessed is not None and has_been_preprocessed:
        tensors_0_1 = tensors_preprocessed

    if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
        skip_unconditional = self.sd_ref().is_flux
        if tensors_0_1 is None:
            # no control image at all: fall back to a random-noise image that is
            # already in the encoder's input format
            tensors_0_1 = self.get_empty_clip_image(batch_size)
            has_been_preprocessed = True

        # image preparation never needs gradients
        with torch.no_grad():
            # on training the clip image is created in the dataloader
            if not has_been_preprocessed:
                # tensors should be 0-1
                if tensors_0_1.ndim == 3:
                    tensors_0_1 = tensors_0_1.unsqueeze(0)
                # training tensors are 0 - 1
                tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
                # if images are out of this range throw error
                if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                    raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                        tensors_0_1.min(), tensors_0_1.max()
                    ))
                clip_image = self.image_processor(
                    images=tensors_0_1,
                    return_tensors="pt",
                    do_resize=True,
                    do_rescale=False,
                ).pixel_values
            else:
                clip_image = tensors_0_1

            # if is pixtral
            if self.config.image_encoder_arch == 'pixtral' and self.config.pixtral_random_image_size:
                # get the random size
                random_size = random.randint(256, self.config.pixtral_max_image_size)
                # images are already sized for max size, we have to fit them to
                # the pixtral patch size to reduce / enlarge it further.
                h, w = clip_image.shape[2], clip_image.shape[3]
                current_base_size = int(math.sqrt(w * h))
                ratio = current_base_size / random_size
                if ratio > 1:
                    w = round(w / ratio)
                    h = round(h / ratio)

                # round each side up to a whole number of patches
                patch_size = self.image_processor.image_patch_size
                width_tokens = (w - 1) // patch_size + 1
                height_tokens = (h - 1) // patch_size + 1
                assert width_tokens > 0
                assert height_tokens > 0

                # F.interpolate expects (out_height, out_width) for 4-D input.
                # Passing (width, height) here would swap the sides of every
                # non-square image.
                new_image_size = (
                    height_tokens * patch_size,
                    width_tokens * patch_size,
                )

                # resize the image
                clip_image = F.interpolate(clip_image, size=new_image_size, mode='bicubic', align_corners=False)

            batch_size = clip_image.shape[0]
            if self.config.control_image_dropout > 0 and is_training:
                # per-sample dropout: replace the control image with noise
                clip_batch = torch.chunk(clip_image, batch_size, dim=0)
                unconditional_batch = torch.chunk(self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
                    clip_image.device, dtype=clip_image.dtype
                ), batch_size, dim=0)
                combine_list = []
                for i in range(batch_size):
                    do_dropout = random.random() < self.config.control_image_dropout
                    if do_dropout:
                        # dropout with noise
                        combine_list.append(unconditional_batch[i])
                    else:
                        combine_list.append(clip_batch[i])
                clip_image = torch.cat(combine_list, dim=0)

            if self.adapter_type in ['vision_direct', 'te_augmenter', 'i2v'] and not skip_unconditional:
                # prepend an unconditional (noise) batch so both halves are
                # encoded in one pass and can be split apart afterwards
                unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
                    clip_image.device, dtype=clip_image.dtype
                )
                clip_image = torch.cat([unconditional, clip_image], dim=0)

            clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

        if self.config.quad_image:
            # split the 2x2 grid into quadrants and stack them on the batch dim
            ci1, ci2 = clip_image.chunk(2, dim=2)
            ci1, ci3 = ci1.chunk(2, dim=3)
            ci2, ci4 = ci2.chunk(2, dim=3)
            to_cat = []
            for i, ci in enumerate([ci1, ci2, ci3, ci4]):
                if i < quad_count:
                    to_cat.append(ci)
                else:
                    break

            clip_image = torch.cat(to_cat, dim=0).detach()

        if self.adapter_type == 'ilora':
            with torch.set_grad_enabled(is_training):
                if is_training and self.config.train_image_encoder:
                    self.vision_encoder.train()
                    clip_image = clip_image.requires_grad_(True)
                    id_embeds = self.vision_encoder(
                        clip_image,
                        output_hidden_states=True,
                    )
                else:
                    with torch.no_grad():
                        self.vision_encoder.eval()
                        id_embeds = self.vision_encoder(
                            clip_image, output_hidden_states=True
                        )

                if self.config.clip_layer == 'penultimate_hidden_states':
                    img_embeds = id_embeds.hidden_states[-2]
                elif self.config.clip_layer == 'last_hidden_state':
                    img_embeds = id_embeds.hidden_states[-1]
                elif self.config.clip_layer == 'image_embeds':
                    img_embeds = id_embeds.image_embeds
                else:
                    raise ValueError(f"unknown clip layer: {self.config.clip_layer}")

                if self.config.quad_image:
                    # average the encoder outputs of the stacked quadrants
                    chunks = img_embeds.chunk(quad_count, dim=0)
                    chunk_sum = torch.zeros_like(chunks[0])
                    for chunk in chunks:
                        chunk_sum = chunk_sum + chunk
                    # get the mean of them

                    img_embeds = chunk_sum / quad_count

                if not is_training or not self.config.train_image_encoder:
                    img_embeds = img_embeds.detach()

                self.ilora_module(img_embeds)

        if self.adapter_type in ['vision_direct', 'te_augmenter', 'i2v']:
            with torch.set_grad_enabled(is_training):
                if is_training and self.config.train_image_encoder:
                    self.vision_encoder.train()
                    clip_image = clip_image.requires_grad_(True)
                else:
                    with torch.no_grad():
                        self.vision_encoder.eval()
                self.vision_encoder.to(self.device)
                clip_output = self.vision_encoder(
                    clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)),
                    output_hidden_states=True,
                )
                if self.config.clip_layer == 'penultimate_hidden_states':
                    # they skip last layer for ip+
                    # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
                    clip_image_embeds = clip_output.hidden_states[-2]
                elif self.config.clip_layer == 'last_hidden_state':
                    clip_image_embeds = clip_output.hidden_states[-1]
                else:
                    if hasattr(clip_output, 'image_embeds'):
                        clip_image_embeds = clip_output.image_embeds
                    elif hasattr(clip_output, 'pooler_output'):
                        clip_image_embeds = clip_output.pooler_output
                    else:
                        # previously this fell through and crashed later with
                        # an UnboundLocalError that hid the real cause
                        raise ValueError(
                            "vision encoder output has neither 'image_embeds' nor 'pooler_output'"
                        )
                    # TODO should we always norm image embeds?
                    # get norm embeddings
                    # l2_norm = torch.norm(clip_image_embeds, p=2)
                    # clip_image_embeds = clip_image_embeds / l2_norm

                if not is_training or not self.config.train_image_encoder:
                    clip_image_embeds = clip_image_embeds.detach()

                if self.adapter_type == 'te_augmenter':
                    clip_image_embeds = self.te_augmenter(clip_image_embeds)

                if self.adapter_type == 'vision_direct':
                    clip_image_embeds = self.vd_adapter(clip_image_embeds)

                # save them to the conditional and unconditional
                try:
                    if skip_unconditional:
                        self.unconditional_embeds, self.conditional_embeds = None, clip_image_embeds
                    else:
                        self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
                except ValueError as e:
                    raise ValueError(
                        f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}"
                    ) from e
def get_additional_save_metadata(self) -> Dict[str, Any]:
    """Collect adapter-specific metadata to store alongside saved weights.

    Only the ``'ilora'`` adapter type contributes anything: whatever
    ``self.ilora_module.get_additional_save_metadata()`` reports, plus the
    ``clip_layer``, ``image_encoder_arch`` and ``head_dim`` config values.
    Every other adapter type yields an empty dict.

    Returns:
        A new dict of metadata key/value pairs.
    """
    additional: Dict[str, Any] = {}
    if self.config.type == 'ilora':
        additional.update(self.ilora_module.get_additional_save_metadata())
        additional['clip_layer'] = self.config.clip_layer
        # This key used to be filled with ``config.head_dim`` by mistake;
        # store the real architecture and keep head_dim under its own key.
        additional['image_encoder_arch'] = self.config.image_encoder_arch
        additional['head_dim'] = self.config.head_dim
    return additional
class RescaleTransform:
    """Map values from the [0, 1] range onto [-1, 1]."""

    def __call__(self, image):
        # 0 -> -1, 0.5 -> 0, 1 -> 1
        doubled = image * 2
        return doubled - 1
class ImageDataset(Dataset, CaptionMixin):
    """Square-crop image dataset driven by a plain config mapping.

    Images are read from ``config['path']``; files whose shorter side (after
    applying ``scale``) is below ``resolution`` are dropped at construction
    time. Each item is a tensor rescaled to [-1, 1], optionally paired with a
    caption when ``include_prompt`` is set.
    """

    def __init__(self, config):
        self.config = config
        self.name = self.get_config('name', 'dataset')
        self.path = self.get_config('path', required=True)
        self.scale = self.get_config('scale', 1)
        self.random_scale = self.get_config('random_scale', False)
        self.include_prompt = self.get_config('include_prompt', False)
        self.default_prompt = self.get_config('default_prompt', '')
        if self.include_prompt:
            self.caption_type = self.get_config('caption_ext', 'txt')
        else:
            self.caption_type = None
        # we always random crop if random scale is enabled
        self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)

        self.resolution = self.get_config('resolution', 256)
        self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
                          file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]

        # this might take a while
        print_acc(f"  -  Preprocessing image dimensions")
        new_file_list = []
        bad_count = 0
        for file in tqdm(self.file_list):
            try:
                w, h = image_utils.get_image_size(file)
            except image_utils.UnknownImageFormat:
                # fall back to PIL; close the handle as soon as the size is known
                with Image.open(file) as opened:
                    w, h = exif_transpose(opened).size
            if int(min([w, h]) * self.scale) >= self.resolution:
                new_file_list.append(file)
            else:
                bad_count += 1

        self.file_list = new_file_list

        print_acc(f"  -  Found {len(self.file_list)} images")
        print_acc(f"  -  Found {bad_count} images that are too small")
        assert len(self.file_list) > 0, f"no images found in {self.path}"

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            RescaleTransform(),
        ])

    def get_config(self, key, default=None, required=False):
        """Return ``config[key]``, or ``default``; raise if a required key is missing."""
        if key in self.config:
            return self.config[key]
        if required:
            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
        return default

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        """Load, scale and crop one image.

        Returns the transformed image, or ``(image, prompt)`` when
        ``include_prompt`` is enabled. An unreadable file is replaced by a
        1024x1024 random-noise image instead of raising.
        """
        img_path = self.file_list[index]
        try:
            # convert() produces an independent image, so the file handle can
            # be released right away
            with Image.open(img_path) as opened:
                img = exif_transpose(opened).convert('RGB')
        except Exception as e:
            print_acc(f"Error opening image: {img_path}")
            print_acc(e)
            # make a noise image if we can't open it
            img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8))

        # Downscale the source image first
        img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
        min_img_size = min(img.size)

        if self.random_crop:
            if self.random_scale and min_img_size > self.resolution:
                # shorter side is strictly larger than the target here, so a
                # random intermediate size can always be drawn
                scale_size = random.randint(self.resolution, int(min_img_size))
                scaler = scale_size / min_img_size
                # the +5 keeps a few spare pixels so rounding never leaves the
                # image smaller than the crop window
                scale_width = int((img.width + 5) * scaler)
                scale_height = int((img.height + 5) * scaler)
                img = img.resize((scale_width, scale_height), Image.BICUBIC)
            img = transforms.RandomCrop(self.resolution)(img)
        else:
            img = transforms.CenterCrop(min_img_size)(img)
            img = img.resize((self.resolution, self.resolution), Image.BICUBIC)

        img = self.transform(img)

        if self.include_prompt:
            prompt = self.get_caption_item(index)
            return img, prompt
        return img
class PairedImageDataset(Dataset):
    """Dataset of (negative, positive) image pairs, or of single images.

    When both ``pos_folder`` and ``neg_folder`` are configured, files with the
    same base name (extension ignored) are paired and each item is the two
    images pasted side by side, negative on the left. Otherwise every image in
    ``path`` is served on its own, resized to ``size`` pixels high.

    Each item is ``(image_tensor, prompt, (neg_weight, pos_weight))``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.size = self.get_config('size', 512)
        self.path = self.get_config('path', None)
        self.pos_folder = self.get_config('pos_folder', None)
        self.neg_folder = self.get_config('neg_folder', None)

        self.default_prompt = self.get_config('default_prompt', '')
        self.network_weight = self.get_config('network_weight', 1.0)
        self.pos_weight = self.get_config('pos_weight', self.network_weight)
        self.neg_weight = self.get_config('neg_weight', self.network_weight)

        # names are lower-cased before matching, so lower-case entries suffice
        supported_exts = ('.jpg', '.jpeg', '.png', '.webp')

        if self.pos_folder is not None and self.neg_folder is not None:
            # find matching files
            self.pos_file_list = [os.path.join(self.pos_folder, file) for file in os.listdir(self.pos_folder) if
                                  file.lower().endswith(supported_exts)]
            self.neg_file_list = [os.path.join(self.neg_folder, file) for file in os.listdir(self.neg_folder) if
                                  file.lower().endswith(supported_exts)]

            # index negatives by base name; the first file seen for a name wins
            neg_by_stem = {}
            for neg_file in self.neg_file_list:
                neg_by_stem.setdefault(self._stem(neg_file), neg_file)

            matched_files = []
            for pos_file in self.pos_file_list:
                neg_file = neg_by_stem.get(self._stem(pos_file))
                if neg_file is not None:
                    matched_files.append((neg_file, pos_file))

            # remove duplicates while keeping a stable, reproducible order
            # (a plain set() would reorder the pairs differently on every run)
            self.file_list = list(dict.fromkeys(matched_files))
            print_acc(f"  -  Found {len(self.file_list)} matching pairs")
        else:
            if self.path is None:
                # os.listdir(None) would silently list the working directory
                raise ValueError('config file error. Missing "config.dataset.path" key')
            self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
                              file.lower().endswith(supported_exts)]
            print_acc(f"  -  Found {len(self.file_list)} images")

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            RescaleTransform(),
        ])

    @staticmethod
    def _stem(file_path):
        """File name without directory or extension."""
        return os.path.basename(os.path.splitext(file_path)[0])

    @staticmethod
    def _load_rgb(file_path):
        """Open an image, apply its EXIF orientation and return an RGB copy."""
        with Image.open(file_path) as opened:
            return exif_transpose(opened).convert('RGB')

    def get_all_prompts(self):
        """Return every distinct prompt, in first-seen order."""
        prompts = [self.get_prompt_item(index) for index in range(len(self.file_list))]
        # remove duplicates, deterministically
        return list(dict.fromkeys(prompts))

    def __len__(self):
        return len(self.file_list)

    def get_config(self, key, default=None, required=False):
        """Return ``config[key]``, or ``default``; raise if a required key is missing."""
        if key in self.config:
            return self.config[key]
        if required:
            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
        return default

    def get_prompt_item(self, index):
        """Return the caption for item ``index``.

        The caption is read from a ``.txt`` file next to the image (for pairs:
        next to the first file, else next to the second). Line breaks are
        treated as commas and empty fragments are dropped. Falls back to
        ``default_prompt`` when no caption file exists.
        """
        img_path_or_tuple = self.file_list[index]
        if isinstance(img_path_or_tuple, tuple):
            # check if either has a prompt file
            prompt_path = os.path.splitext(img_path_or_tuple[0])[0] + '.txt'
            if not os.path.exists(prompt_path):
                prompt_path = os.path.splitext(img_path_or_tuple[1])[0] + '.txt'
        else:
            # see if prompt file exists
            prompt_path = os.path.splitext(img_path_or_tuple)[0] + '.txt'

        if not os.path.exists(prompt_path):
            return self.default_prompt

        with open(prompt_path, 'r', encoding='utf-8') as f:
            prompt = f.read()
        # turn line breaks (any OS flavour) into separators
        prompt = prompt.replace('\n', ', ')
        prompt = prompt.replace('\r', ', ')
        # drop empty fragments and normalise spacing
        fragments = [p.strip() for p in prompt.split(',') if p.strip()]
        return ', '.join(fragments)

    def __getitem__(self, index):
        img_path_or_tuple = self.file_list[index]
        if isinstance(img_path_or_tuple, tuple):
            # load both images: index 0 is the negative, index 1 the positive
            img1 = self._load_rgb(img_path_or_tuple[0])
            img2 = self._load_rgb(img_path_or_tuple[1])

            # always use # 2 (pos)
            bucket_resolution = get_bucket_for_image_size(
                width=img2.width,
                height=img2.height,
                resolution=self.size,
            )
            bucket_width = bucket_resolution["width"]
            bucket_height = bucket_resolution["height"]

            # images will be same base dimension, but may be trimmed.
            # We need to shrink and then central crop
            if bucket_width > bucket_height:
                img1_scale_to_height = bucket_height
                img1_scale_to_width = int(img1.width * (bucket_height / img1.height))
                img2_scale_to_height = bucket_height
                img2_scale_to_width = int(img2.width * (bucket_height / img2.height))
            else:
                img1_scale_to_width = bucket_width
                img1_scale_to_height = int(img1.height * (bucket_width / img1.width))
                img2_scale_to_width = bucket_width
                img2_scale_to_height = int(img2.height * (bucket_width / img2.width))

            # scale then center crop images
            center_crop = transforms.CenterCrop((bucket_height, bucket_width))
            img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
            img1 = center_crop(img1)
            img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
            img2 = center_crop(img2)

            # combine them side by side
            img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
            img.paste(img1, (0, 0))
            img.paste(img2, (img1.width, 0))
        else:
            img = self._load_rgb(img_path_or_tuple)
            height = self.size
            # determine width to keep aspect ratio
            width = int(img.size[0] * height / img.size[1])

            # Downscale the source image first
            img = img.resize((width, height), Image.BICUBIC)

        prompt = self.get_prompt_item(index)
        img = self.transform(img)

        return img, prompt, (self.neg_weight, self.pos_weight)
def __init__(
        self,
        dataset_config: 'DatasetConfig',
        batch_size=1,
        sd: 'StableDiffusion' = None,
):
    """Scan the dataset location, build the file list and run first-epoch setup.

    Args:
        dataset_config: Dataset settings. ``bucket_tolerance`` is overwritten
            with the model's bucket divisibility when ``sd`` is given.
        batch_size: Stored on the instance as ``self.batch_size``.
        sd: The model wrapper. Optional, but required when latents are cached.
            When it is omitted, model-dependent settings fall back to their
            "sd1" / SD 1.5 defaults.

    Raises:
        ValueError: If latent caching is requested without ``sd``.
        AssertionError: If no usable image / video file was found.
    """
    self.dataset_config = dataset_config
    # update bucket divisibility (only the model knows it; keep the configured
    # value when no model was supplied instead of crashing on ``None``)
    if sd is not None:
        self.dataset_config.bucket_tolerance = sd.get_bucket_divisibility()
    self.is_video = dataset_config.num_frames > 1 or dataset_config.auto_frame_count
    self.is_audio_model = hasattr(sd, 'is_audio_model') and sd.is_audio_model if sd is not None else False
    super().__init__()
    folder_path = dataset_config.folder_path
    self.dataset_path = dataset_config.dataset_path
    if self.dataset_path is None:
        self.dataset_path = folder_path

    self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
    self.is_caching_latents_to_memory = dataset_config.cache_latents
    self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
    self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
    self.is_generating_controls = len(dataset_config.controls) > 0
    self.epoch_num = 0

    self.sd = sd

    if self.sd is None and self.is_caching_latents:
        raise ValueError(f"sd is required for caching latents")

    self.caption_type = dataset_config.caption_ext
    self.default_caption = dataset_config.default_caption
    self.random_scale = dataset_config.random_scale
    self.scale = dataset_config.scale
    self.batch_size = batch_size
    # we always random crop if random scale is enabled
    self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
    self.resolution = dataset_config.resolution
    self.caption_dict = None
    self.file_list: List['FileItemDTO'] = []

    # check if dataset_path is a folder or json
    if os.path.isdir(self.dataset_path):
        extensions = image_extensions
        if self.is_audio_model:
            # only look for audio files
            extensions = audio_extensions
        elif self.is_video:
            # only look for videos
            extensions = video_extensions
        wanted = tuple(extensions)
        # recursive scan; hidden files (leading dot) are ignored
        file_list = [
            os.path.join(root, file)
            for root, _, files in os.walk(self.dataset_path)
            for file in files
            if file.lower().endswith(wanted) and not file.startswith('.')
        ]
    else:
        # assume json
        with open(self.dataset_path, 'r') as f:
            self.caption_dict = json.load(f)
            # keys are file paths
            file_list = list(self.caption_dict.keys())

    # remove items in the _controls_ folder
    file_list = [x for x in file_list if not os.path.basename(os.path.dirname(x)) == "_controls"]

    if self.dataset_config.num_repeats > 1:
        # repeat the list
        file_list = file_list * self.dataset_config.num_repeats

    if self.dataset_config.standardize_images:
        # without a model there is nothing to inspect: use the SD 1.5 stats,
        # matching the "sd1" latent-space fallback below
        if self.sd is not None and (self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd):
            NormalizeMethod = NormalizeSDXLTransform
        else:
            NormalizeMethod = NormalizeSD15Transform

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            RescaleTransform(),
            NormalizeMethod(),
        ])
    else:
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            RescaleTransform(),
        ])

    # this might take a while
    print_acc(f"Dataset: {self.dataset_path}")
    if self.is_video:
        print_acc(f"  -  Preprocessing video dimensions")
    else:
        print_acc(f"  -  Preprocessing image dimensions")
    dataset_folder = self.dataset_path
    if not os.path.isdir(self.dataset_path):
        dataset_folder = os.path.dirname(dataset_folder)

    # size cache stored next to the data; discarded when its version differs
    dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
    dataloader_version = "0.1.2"
    if os.path.exists(dataset_size_file):
        try:
            with open(dataset_size_file, 'r') as f:
                self.size_database = json.load(f)

            if "__version__" not in self.size_database or self.size_database["__version__"] != dataloader_version:
                print_acc("Upgrading size database to new version")
                # old version, delete and recreate
                self.size_database = {}
        except Exception as e:
            print_acc(f"Error loading size database: {dataset_size_file}")
            print_acc(e)
            self.size_database = {}
    else:
        self.size_database = {}

    self.size_database["__version__"] = dataloader_version

    # set latent space version; "sd1" is the fallback when no model is given.
    # The whole chain is guarded so a missing model can no longer raise
    # AttributeError before the fallback is reached.
    latent_space_version = "sd1"
    if self.sd is not None:
        if self.sd.model_config.latent_space_version is not None:
            latent_space_version = self.sd.model_config.latent_space_version
        elif self.sd.latent_space_version is not None:
            latent_space_version = self.sd.latent_space_version
        elif self.sd.is_xl:
            latent_space_version = 'sdxl'
        elif self.sd.is_v3:
            latent_space_version = 'sd3'
        elif self.sd.is_auraflow:
            latent_space_version = 'sdxl'
        elif self.sd.is_flux:
            latent_space_version = 'flux1'
        elif self.sd.model_config.is_pixart_sigma:
            latent_space_version = 'sdxl'
        else:
            latent_space_version = self.sd.model_config.arch

    temporal_compression = 8
    if self.sd is not None:
        if hasattr(self.sd.vae, 'config') and hasattr(self.sd.vae.config, 'scale_factor_temporal'):
            temporal_compression = self.sd.vae.config.scale_factor_temporal
        # a value on the unet config takes precedence over the vae one
        if hasattr(self.sd.unet, 'config') and hasattr(self.sd.unet.config, 'temporal_compression_ratio'):
            temporal_compression = self.sd.unet.config.temporal_compression_ratio

    bad_count = 0
    for file in tqdm(file_list):
        try:
            file_item = FileItemDTO(
                sd=self.sd,
                path=file,
                is_audio_model=self.is_audio_model,
                dataset_config=dataset_config,
                dataloader_transforms=self.transform,
                size_database=self.size_database,
                dataset_root=dataset_folder,
                encode_control_in_text_embeddings=self.sd.encode_control_in_text_embeddings if self.sd else False,
                text_embedding_space_version=self.sd.text_embedding_space_version if self.sd else "sd1",
                te_padding_side=self.sd.te_padding_side if self.sd else "right",
                latent_space_version=latent_space_version,
                temporal_compression=temporal_compression,
                sample_rate=self.sd.sample_rate if self.is_audio_model and self.sd is not None else 48000,
            )
            self.file_list.append(file_item)
        except Exception as e:
            # a single bad file must not abort the whole scan: log and count it
            print_acc(traceback.format_exc())
            if self.is_video:
                print_acc(f"Error processing video: {file}")
            else:
                print_acc(f"Error processing image: {file}")
            print_acc(e)
            bad_count += 1

    # save the size database
    with open(dataset_size_file, 'w') as f:
        json.dump(self.size_database, f)

    if self.is_video:
        print_acc(f"  -  Found {len(self.file_list)} videos")
        assert len(self.file_list) > 0, f"no videos found in {self.dataset_path}"
    else:
        print_acc(f"  -  Found {len(self.file_list)} images")
        assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"

    # handle x axis flips
    if self.dataset_config.flip_x:
        print_acc("  -  adding x axis flips")
        # iterate over a snapshot: the list grows while we append
        for file_item in list(self.file_list):
            # create a copy that is flipped on the x axis
            new_file_item = copy.deepcopy(file_item)
            new_file_item.flip_x = True
            self.file_list.append(new_file_item)

    # handle y axis flips (applied to the x-flipped copies as well)
    if self.dataset_config.flip_y:
        print_acc("  -  adding y axis flips")
        for file_item in list(self.file_list):
            # create a copy that is flipped on the y axis
            new_file_item = copy.deepcopy(file_item)
            new_file_item.flip_y = True
            self.file_list.append(new_file_item)

    if self.dataset_config.flip_x or self.dataset_config.flip_y:
        if self.is_video:
            print_acc(f"  -  Found {len(self.file_list)} videos after adding flips")
        else:
            print_acc(f"  -  Found {len(self.file_list)} images after adding flips")

    self.setup_epoch()
+ + def _get_replacement_index(self, index) -> int: + # when an image fails to load we have to swap in a different one. With buckets the + # replacement must come from the same bucket so the collated shapes still match. + if self.dataset_config.buckets: + for bucket in self.buckets.values(): + if index in bucket.file_list_idx: + candidates = [i for i in bucket.file_list_idx if i != index] + if candidates: + return random.choice(candidates) + break + return random.randint(0, len(self.file_list) - 1) + + def _get_single_item(self, index, _attempts=0) -> 'FileItemDTO': + file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index]) + try: + file_item.load_and_process_image(self.transform) + except Exception as e: + print(f"Error loading image, skipping and loading a different one: {file_item.path} ({e})") + if _attempts >= 10: + # avoid infinite recursion if many files are corrupt + raise + new_index = self._get_replacement_index(index) + return self._get_single_item(new_index, _attempts=_attempts + 1) + file_item.load_caption(self.caption_dict) + return file_item + + def __getitem__(self, item): + if self.dataset_config.buckets: + # for buckets we collate ourselves for now + # todo allow a scheduler to dynamically make buckets + # we collate ourselves + if len(self.batch_indices) - 1 < item: + # tried everything to solve this. No way to reset length when redoing things. 
def get_dataloader_from_datasets(
        dataset_options,
        batch_size=1,
        sd: 'StableDiffusion' = None,
) -> DataLoader:
    """Build one shuffled ``DataLoader`` over all configured datasets.

    Args:
        dataset_options: iterable of ``DatasetConfig`` instances and/or raw
            config entries. Raw entries are expanded through
            ``preprocess_dataset_raw_config`` and turned into ``DatasetConfig``
            objects.
        batch_size: batch size handed to every ``AiToolkitDataset``; also the
            ``DataLoader`` batch size when no dataset uses buckets.
        sd: model wrapper forwarded to every dataset.

    Returns:
        The ``DataLoader``, or ``None`` when ``dataset_options`` is ``None``
        or empty.

    Raises:
        ValueError: if a config has a ``type`` other than ``'image'``.
        AssertionError: if only some of the datasets have buckets enabled.
    """
    if dataset_options is None or len(dataset_options) == 0:
        return None

    # normalise every option to a DatasetConfig
    dataset_config_list = []
    for dataset_option in dataset_options:
        if isinstance(dataset_option, DatasetConfig):
            dataset_config_list.append(dataset_option)
        else:
            # preprocess raw data
            split_configs = preprocess_dataset_raw_config([dataset_option])
            for x in split_configs:
                dataset_config_list.append(DatasetConfig(**x))

    datasets = []
    has_buckets = False
    for config in dataset_config_list:
        if config.type != 'image':
            raise ValueError(f"invalid dataset type: {config.type}")
        datasets.append(AiToolkitDataset(config, batch_size=batch_size, sd=sd))
        if config.buckets:
            has_buckets = True

    concatenated_dataset = ConcatDataset(datasets)

    # todo build scheduler that can get buckets from all datasets that match
    # todo and evenly distribute reg images

    def dto_collation(batch: List['FileItemDTO']):
        # wrap the file items in a batch DTO
        return DataLoaderBatchDTO(file_items=batch)

    dataloader_kwargs = {}
    if is_native_windows() or is_macos():
        dataloader_kwargs['num_workers'] = 0
    else:
        num_workers = dataset_config_list[0].num_workers
        dataloader_kwargs['num_workers'] = num_workers
        # DataLoader rejects prefetch_factor unless worker processes are used,
        # so only forward it when num_workers is non-zero.
        if num_workers:
            dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor

    if has_buckets:
        # make sure they all have buckets
        for dataset in datasets:
            assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"

        data_loader = DataLoader(
            concatenated_dataset,
            batch_size=None,  # we batch in the datasets for now
            drop_last=False,
            shuffle=True,
            collate_fn=dto_collation,  # Use the custom collate function
            **dataloader_kwargs
        )
    else:
        data_loader = DataLoader(
            concatenated_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=dto_collation,
            **dataloader_kwargs
        )
    return data_loader
class FileItemDTO(
    LatentCachingFileItemDTOMixin,
    TextEmbeddingFileItemDTOMixin,
    CaptionProcessingDTOMixin,
    ImageProcessingDTOMixin,
    AudioProcessingDTOMixin,
    ControlFileItemDTOMixin,
    InpaintControlFileItemDTOMixin,
    ClipImageFileItemDTOMixin,
    MaskFileItemDTOMixin,
    AugmentationFileItemDTOMixin,
    UnconditionalFileItemDTOMixin,
    ArgBreakMixin,
):
    """One dataset entry (image, video or audio file) plus its sizing info.

    All configuration arrives through keyword arguments (``path``,
    ``dataset_config``, ``size_database``, ``dataset_root``, ...). The
    constructor determines the item's ``width``/``height`` and then derives
    the default scale and crop values from them.
    """

    def __init__(self, *args, **kwargs):
        """Measure the file and initialise scale / crop / flip attributes.

        The measured size is stored in the ``size_database`` mapping passed by
        the caller (keyed by the path relative to ``dataset_root``) together
        with the file signature, and is reused on later runs while the
        signature still matches. Audio-model items are always measured and
        use ``width`` = duration in milliseconds and ``height`` = 1.

        Raises:
            Exception: if no file signature can be computed or a video file
                cannot be opened.
        """
        self.path = kwargs.get("path", "")
        self.dataset_config: "DatasetConfig" = kwargs.get("dataset_config", None)
        self.is_video = self.dataset_config.num_frames > 1 or self.dataset_config.auto_frame_count
        self.is_audio_model = kwargs.get("is_audio_model", False)
        self.sample_rate = kwargs.get("sample_rate", 48000)
        self.num_frames = self.dataset_config.num_frames
        self.temporal_compression = kwargs.get("temporal_compression", 8)
        size_database = kwargs.get("size_database", {})
        dataset_root = kwargs.get("dataset_root", None)
        self.encode_control_in_text_embeddings = kwargs.get(
            "encode_control_in_text_embeddings", False
        )
        self.te_padding_side = kwargs.get("te_padding_side", "right")
        self.latent_space_version = kwargs.get("latent_space_version", "sd1")
        self.text_embedding_space_version = kwargs.get("text_embedding_space_version", "sd1")
        if dataset_root is not None:
            # remove dataset root from path
            file_key = self.path.replace(dataset_root, "")
        else:
            file_key = os.path.basename(self.path)

        file_signature = get_quick_signature_string(self.path)
        if file_signature is None:
            raise Exception(f"Error: Could not get file signature for {self.path}")

        # a cached size is only trusted while the stored signature still matches
        db_entry = None
        use_db_entry = False
        if file_key in size_database:
            db_entry = size_database[file_key]
            if (
                db_entry is not None
                and len(db_entry) >= 3
                and db_entry[2] == file_signature
            ):
                use_db_entry = True

        if self.is_audio_model:
            # get the length of the audio file in ms
            with av.open(self.path) as c:
                if c.duration is not None:
                    w = int(c.duration / 1_000)
                else:
                    s = c.streams.audio[0]
                    w = int(float(s.duration * s.time_base) * 1_000)
            h = 1
        elif use_db_entry:
            # only the first two fields are the size; tolerate longer entries
            w, h = db_entry[0], db_entry[1]
        elif self.is_video:
            video = cv2.VideoCapture(self.path)
            try:
                if not video.isOpened():
                    raise Exception(f"Error: Could not open video file {self.path}")
                w = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
                h = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                # release the capture object immediately, also on failure
                video.release()
            size_database[file_key] = (w, h, file_signature)
        else:
            w = h = None
            if self.dataset_config.fast_image_size:
                # original method is significantly faster, but some images are
                # read sideways. Not sure why. Do slow method by default.
                try:
                    w, h = image_utils.get_image_size(self.path)
                except image_utils.UnknownImageFormat:
                    print_once(
                        f"Warning: Some images in the dataset cannot be fast read. "
                        + f"This process is faster for png, jpeg"
                    )
            if w is None:
                # slow path: open the image and apply the EXIF orientation;
                # the context manager closes the file handle right away
                with Image.open(self.path) as img:
                    w, h = exif_transpose(img).size
            size_database[file_key] = (w, h, file_signature)

        self.width: int = w
        self.height: int = h
        self.dataloader_transforms = kwargs.get("dataloader_transforms", None)
        super().__init__(*args, **kwargs)

        # self.caption_path: str = kwargs.get('caption_path', None)
        self.raw_caption: str = kwargs.get("raw_caption", None)
        # we scale first, then crop
        self.scale_to_width: int = kwargs.get(
            "scale_to_width", int(self.width * self.dataset_config.scale)
        )
        self.scale_to_height: int = kwargs.get(
            "scale_to_height", int(self.height * self.dataset_config.scale)
        )
        # crop values are from scaled size
        self.crop_x: int = kwargs.get("crop_x", 0)
        self.crop_y: int = kwargs.get("crop_y", 0)
        self.crop_width: int = kwargs.get("crop_width", self.scale_to_width)
        self.crop_height: int = kwargs.get("crop_height", self.scale_to_height)
        self.flip_x: bool = kwargs.get("flip_x", False)
        # read the y flag from its own key (it used to mirror "flip_x")
        self.flip_y: bool = kwargs.get("flip_y", False)
        self.augments: List[str] = self.dataset_config.augments
        self.loss_multiplier: float = self.dataset_config.loss_multiplier

        self.network_weight: float = self.dataset_config.network_weight
        self.is_reg = self.dataset_config.is_reg
        self.prior_reg = self.dataset_config.prior_reg
        self.tensor: Union[torch.Tensor, None] = None
        self.audio_data = None
        self.audio_tensor = None

    def cleanup(self):
        """Drop the loaded tensors / audio and run every mixin's cleanup hook."""
        self.tensor = None
        self.audio_data = None
        self.audio_tensor = None
        self.cleanup_latent()
        self.cleanup_text_embedding()
        self.cleanup_control()
        self.cleanup_inpaint()
        self.cleanup_clip_image()
        self.cleanup_mask()
        self.cleanup_unconditional()
DataLoaderBatchDTO: + def __init__(self, **kwargs): + try: + self.file_items: List["FileItemDTO"] = kwargs.get("file_items", None) + is_latents_cached = self.file_items[0].is_latent_cached + self.tensor: Union[torch.Tensor, None] = None + self.latents: Union[torch.Tensor, None] = None + self.control_tensor: Union[torch.Tensor, None] = None + self.control_tensor_list: Union[List[List[torch.Tensor]], None] = None + self.clip_image_tensor: Union[torch.Tensor, None] = None + self.mask_tensor: Union[torch.Tensor, None] = None + self.unaugmented_tensor: Union[torch.Tensor, None] = None + self.unconditional_tensor: Union[torch.Tensor, None] = None + self.unconditional_latents: Union[torch.Tensor, None] = None + self.clip_image_embeds: Union[List[dict], None] = None + self.clip_image_embeds_unconditional: Union[List[dict], None] = None + self.sigmas: Union[torch.Tensor, None] = ( + None # can be added elseware and passed along training code + ) + self.extra_values: Union[torch.Tensor, None] = ( + torch.tensor([x.extra_values for x in self.file_items]) + if len(self.file_items[0].extra_values) > 0 + else None + ) + self.audio_data: Union[List, None] = ( + [x.audio_data for x in self.file_items] + if self.file_items[0].audio_data is not None + else None + ) + self.audio_tensor: Union[torch.Tensor, None] = None + self.first_frame_latents: Union[torch.Tensor, None] = None + self.audio_latents: Union[torch.Tensor, None] = None + + # just for holding noise and preds during training + self.audio_target: Union[torch.Tensor, None] = None + self.audio_pred: Union[torch.Tensor, None] = None + + self.num_frames: int = self.file_items[0].num_frames + + if not is_latents_cached or self.file_items[0].dataset_config.load_image_when_caching_latents: + # only return a tensor if latents are not cached, or if we are explicitly + # loading the raw image alongside the cached latents + self.tensor: torch.Tensor = torch.cat( + [x.tensor.unsqueeze(0) for x in self.file_items] + ) + # if we have 
encoded latents, we concatenate them + self.latents: Union[torch.Tensor, None] = None + if is_latents_cached: + # this get_latent call with trigger loading all cached items from the disk + self.latents = torch.cat( + [x.get_latent().unsqueeze(0) for x in self.file_items] + ) + if any( + [x._cached_first_frame_latent is not None for x in self.file_items] + ): + self.first_frame_latents = torch.cat( + [ + x._cached_first_frame_latent.unsqueeze(0) + if x._cached_first_frame_latent is not None + else torch.zeros_like( + self.file_items[0]._cached_first_frame_latent + ).unsqueeze(0) + for x in self.file_items + ] + ) + if any([x._cached_audio_latent is not None for x in self.file_items]): + self.audio_latents = torch.cat( + [ + x._cached_audio_latent.unsqueeze(0) + if x._cached_audio_latent is not None + else torch.zeros_like( + self.file_items[0]._cached_audio_latent + ).unsqueeze(0) + for x in self.file_items + ] + ) + + self.prompt_embeds: Union[PromptEmbeds, None] = None + # if self.file_items[0].control_tensor is not None: + # if any have a control tensor, we concatenate them + if any([x.control_tensor is not None for x in self.file_items]): + # find one to use as a base + base_control_tensor = None + for x in self.file_items: + if x.control_tensor is not None: + base_control_tensor = x.control_tensor + break + control_tensors = [] + for x in self.file_items: + if x.control_tensor is None: + control_tensors.append(torch.zeros_like(base_control_tensor)) + else: + control_tensors.append(x.control_tensor) + self.control_tensor = torch.cat( + [x.unsqueeze(0) for x in control_tensors] + ) + + # handle control tensor list + if any([x.control_tensor_list is not None for x in self.file_items]): + self.control_tensor_list = [] + for x in self.file_items: + if x.control_tensor_list is not None: + self.control_tensor_list.append(x.control_tensor_list) + else: + raise Exception( + f"Could not find control tensors for all file items, missing for {x.path}" + ) + + 
self.inpaint_tensor: Union[torch.Tensor, None] = None + if any([x.inpaint_tensor is not None for x in self.file_items]): + # find one to use as a base + base_inpaint_tensor = None + for x in self.file_items: + if x.inpaint_tensor is not None: + base_inpaint_tensor = x.inpaint_tensor + break + inpaint_tensors = [] + for x in self.file_items: + if x.inpaint_tensor is None: + inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor)) + else: + inpaint_tensors.append(x.inpaint_tensor) + self.inpaint_tensor = torch.cat( + [x.unsqueeze(0) for x in inpaint_tensors] + ) + + self.loss_multiplier_list: List[float] = [ + x.loss_multiplier for x in self.file_items + ] + + if any([x.clip_image_tensor is not None for x in self.file_items]): + # find one to use as a base + base_clip_image_tensor = None + for x in self.file_items: + if x.clip_image_tensor is not None: + base_clip_image_tensor = x.clip_image_tensor + break + clip_image_tensors = [] + for x in self.file_items: + if x.clip_image_tensor is None: + clip_image_tensors.append( + torch.zeros_like(base_clip_image_tensor) + ) + else: + clip_image_tensors.append(x.clip_image_tensor) + self.clip_image_tensor = torch.cat( + [x.unsqueeze(0) for x in clip_image_tensors] + ) + + if any([x.mask_tensor is not None for x in self.file_items]): + # find one to use as a base + base_mask_tensor = None + for x in self.file_items: + if x.mask_tensor is not None: + base_mask_tensor = x.mask_tensor + break + mask_tensors = [] + for x in self.file_items: + if x.mask_tensor is None: + mask_tensors.append(torch.zeros_like(base_mask_tensor)) + else: + mask_tensors.append(x.mask_tensor) + self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors]) + + # add unaugmented tensors for ones with augments + if any([x.unaugmented_tensor is not None for x in self.file_items]): + # find one to use as a base + base_unaugmented_tensor = None + for x in self.file_items: + if x.unaugmented_tensor is not None: + base_unaugmented_tensor = 
x.unaugmented_tensor + break + unaugmented_tensor = [] + for x in self.file_items: + if x.unaugmented_tensor is None: + unaugmented_tensor.append( + torch.zeros_like(base_unaugmented_tensor) + ) + else: + unaugmented_tensor.append(x.unaugmented_tensor) + self.unaugmented_tensor = torch.cat( + [x.unsqueeze(0) for x in unaugmented_tensor] + ) + + # add unconditional tensors + if any([x.unconditional_tensor is not None for x in self.file_items]): + # find one to use as a base + base_unconditional_tensor = None + for x in self.file_items: + if x.unaugmented_tensor is not None: + base_unconditional_tensor = x.unconditional_tensor + break + unconditional_tensor = [] + for x in self.file_items: + if x.unconditional_tensor is None: + unconditional_tensor.append( + torch.zeros_like(base_unconditional_tensor) + ) + else: + unconditional_tensor.append(x.unconditional_tensor) + self.unconditional_tensor = torch.cat( + [x.unsqueeze(0) for x in unconditional_tensor] + ) + + if any([x.clip_image_embeds is not None for x in self.file_items]): + self.clip_image_embeds = [] + for x in self.file_items: + if x.clip_image_embeds is not None: + self.clip_image_embeds.append(x.clip_image_embeds) + else: + raise Exception("clip_image_embeds is None for some file items") + + if any( + [x.clip_image_embeds_unconditional is not None for x in self.file_items] + ): + self.clip_image_embeds_unconditional = [] + for x in self.file_items: + if x.clip_image_embeds_unconditional is not None: + self.clip_image_embeds_unconditional.append( + x.clip_image_embeds_unconditional + ) + else: + raise Exception( + "clip_image_embeds_unconditional is None for some file items" + ) + + if any([x.prompt_embeds is not None for x in self.file_items]): + # find one to use as a base + base_prompt_embeds = None + for x in self.file_items: + if x.prompt_embeds is not None: + base_prompt_embeds = x.prompt_embeds + break + prompt_embeds_list = [] + for x in self.file_items: + if x.prompt_embeds is None: + y = 
base_prompt_embeds + else: + y = x.prompt_embeds + if x.text_embedding_space_version == "zimage": + # z image needs to be a list if it is not already + if not isinstance(y.text_embeds, list): + y.text_embeds = [y.text_embeds] + prompt_embeds_list.append(y) + padding_side = self.file_items[0].te_padding_side + + self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list, padding_side=padding_side) + + if any([x.audio_tensor is not None for x in self.file_items]): + # find one to use as a base + base_audio_tensor = None + for x in self.file_items: + if x.audio_tensor is not None: + base_audio_tensor = x.audio_tensor + break + audio_tensors = [] + for x in self.file_items: + if x.audio_tensor is None: + audio_tensors.append(torch.zeros_like(base_audio_tensor)) + else: + audio_tensors.append(x.audio_tensor) + self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors]) + + except Exception as e: + print(e) + raise e + + def get_is_reg_list(self): + return [x.is_reg for x in self.file_items] + + def get_network_weight_list(self): + return [x.network_weight for x in self.file_items] + + def get_caption_list( + self, trigger=None, to_replace_list=None, add_if_not_present=True + ): + return [x.caption for x in self.file_items] + + def get_caption_short_list( + self, trigger=None, to_replace_list=None, add_if_not_present=True + ): + return [x.caption_short for x in self.file_items] + + def cleanup(self): + del self.latents + del self.tensor + del self.control_tensor + del self.audio_tensor + del self.audio_data + del self.audio_target + del self.audio_pred + del self.first_frame_latents + del self.audio_latents + for file_item in self.file_items: + file_item.cleanup() + + @property + def dataset_config(self) -> "DatasetConfig": + if len(self.file_items) > 0: + return self.file_items[0].dataset_config + else: + return None diff --git a/ai-toolkit/toolkit/dataloader_mixins.py b/ai-toolkit/toolkit/dataloader_mixins.py new file mode 100644 index 
class Augments:
    """Description of one augmentation: a method name plus its parameters.

    Keyword arguments:
        method: stored as ``method_name`` (``None`` when omitted).
        params: mapping of parameter names to values (empty when omitted or
            ``None``). String values of the form ``"cv2.<NAME>"`` are replaced
            by the attribute ``<NAME>`` (upper-cased) of the ``cv2`` module.

    Raises:
        ValueError: when a ``"cv2.<NAME>"`` value does not name an attribute
            of ``cv2`` after upper-casing.
    """

    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        # work on a copy so the caller's dict is not rewritten in place
        self.params = dict(kwargs.get('params') or {})

        # convert kwargs enums for cv2
        for key, value in list(self.params.items()):
            if not isinstance(value, str):
                continue
            split_string = value.split('.')
            if len(split_string) != 2 or split_string[0] != 'cv2':
                continue
            # validate the same (upper-cased) name that is looked up
            enum_name = split_string[1].upper()
            if not hasattr(cv2, enum_name):
                raise ValueError(f"invalid cv2 enum: {split_string[1]}")
            self.params[key] = getattr(cv2, enum_name)
+ k = 0.7071 + return torch.stack([fl + k * fc + k * bl, fr + k * fc + k * br]) + if c == 8: # 7.1: FL, FR, FC, LFE, BL, BR, SL, SR + fl, fr, fc, _, bl, br, sl, sr = waveform + k = 0.7071 + return torch.stack([fl + k * fc + k * (bl + sl), fr + k * fc + k * (br + sr)]) + return waveform.mean(0, keepdim=True).expand(2, -1) + + +class CaptionMixin: + def get_caption_item(self: 'AiToolkitDataset', index): + if not hasattr(self, 'caption_type'): + raise Exception('caption_type not found on class instance') + if not hasattr(self, 'file_list'): + raise Exception('file_list not found on class instance') + img_path_or_tuple = self.file_list[index] + ext = self.dataset_config.caption_ext + if isinstance(img_path_or_tuple, tuple): + img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path + # check if either has a prompt file + path_no_ext = os.path.splitext(img_path)[0] + prompt_path = None + prompt_path = path_no_ext + ext + else: + img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path + # see if prompt file exists + path_no_ext = os.path.splitext(img_path)[0] + prompt_path = path_no_ext + ext + + # allow folders to have a default prompt + default_prompt_path = os.path.join(os.path.dirname(img_path), 'default.txt') + default_prompt_path_with_ext = os.path.join(os.path.dirname(img_path), 'default' + ext) + + if os.path.exists(prompt_path): + with open(prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() + prompt = clean_caption(prompt) + elif os.path.exists(default_prompt_path_with_ext): + with open(default_prompt_path_with_ext, 'r', encoding='utf-8') as f: + prompt = f.read() + prompt = clean_caption(prompt) + elif os.path.exists(default_prompt_path): + with open(default_prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() + prompt = clean_caption(prompt) + else: + prompt = '' + # get default_prompt if it exists on the class instance + if hasattr(self, 
'default_prompt'): + prompt = self.default_prompt + if hasattr(self, 'default_caption'): + prompt = self.default_caption + + # handle replacements + replacement_list = self.dataset_config.replacements if isinstance(self.dataset_config.replacements, list) else [] + for replacement in replacement_list: + from_string, to_string = replacement.split('|') + prompt = prompt.replace(from_string, to_string) + + return prompt + + +if TYPE_CHECKING: + from toolkit.config_modules import DatasetConfig + from toolkit.data_transfer_object.data_loader import FileItemDTO + + +class Bucket: + def __init__(self, width: int, height: int): + self.width = width + self.height = height + self.file_list_idx: List[int] = [] + + +class BucketsMixin: + def __init__(self): + self.buckets: Dict[str, Bucket] = {} + self.batch_indices: List[List[int]] = [] + + def build_batch_indices(self: 'AiToolkitDataset'): + self.batch_indices = [] + for key, bucket in self.buckets.items(): + for start_idx in range(0, len(bucket.file_list_idx), self.batch_size): + end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx)) + batch = bucket.file_list_idx[start_idx:end_idx] + # if the bucket has fewer items left than the requested batch size, + # duplicate items from this batch to pad it up to batch_size + if len(batch) < self.batch_size and len(batch) > 0: + pad = [batch[i % len(batch)] for i in range(self.batch_size - len(batch))] + batch = batch + pad + self.batch_indices.append(batch) + + def shuffle_buckets(self: 'AiToolkitDataset'): + for key, bucket in self.buckets.items(): + random.shuffle(bucket.file_list_idx) + + def setup_buckets(self: 'AiToolkitDataset', quiet=False): + if not hasattr(self, 'file_list'): + raise Exception(f'file_list not found on class instance {self.__class__.__name__}') + if not hasattr(self, 'dataset_config'): + raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}') + + if self.epoch_num > 0: + # no need to rebuild buckets for now + # 
todo handle random cropping for buckets + return + self.buckets = {} # clear it + + config: 'DatasetConfig' = self.dataset_config + resolution = config.resolution + bucket_tolerance = config.bucket_tolerance + file_list: List['FileItemDTO'] = self.file_list + + # for file_item in enumerate(file_list): + for idx, file_item in enumerate(file_list): + file_item: 'FileItemDTO' = file_item + if self.is_audio_model: + bucket_key = f"{file_item.width}ms" + if bucket_key not in self.buckets: + self.buckets[bucket_key] = Bucket(file_item.width, 1) + self.buckets[bucket_key].file_list_idx.append(idx) + continue + width = int(file_item.width * file_item.dataset_config.scale) + height = int(file_item.height * file_item.dataset_config.scale) + + if self.dataset_config.square_crop: + # we scale first so smallest size matches resolution + scale_factor_x = resolution / width + scale_factor_y = resolution / height + scale_factor = max(scale_factor_x, scale_factor_y) + file_item.scale_to_width = math.ceil(width * scale_factor) + file_item.scale_to_height = math.ceil(height * scale_factor) + file_item.crop_width = resolution + file_item.crop_height = resolution + if width > height: + file_item.crop_x = int(file_item.scale_to_width / 2 - resolution / 2) + file_item.crop_y = 0 + else: + file_item.crop_x = 0 + file_item.crop_y = int(file_item.scale_to_height / 2 - resolution / 2) + else: + bucket_resolution = get_bucket_for_image_size( + width, height, + resolution=resolution, + divisibility=bucket_tolerance + ) + + # Calculate scale factors for width and height + width_scale_factor = bucket_resolution["width"] / width + height_scale_factor = bucket_resolution["height"] / height + + # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution + max_scale_factor = max(width_scale_factor, height_scale_factor) + + # round up + file_item.scale_to_width = int(math.ceil(width * max_scale_factor)) + file_item.scale_to_height = int(math.ceil(height * 
class CaptionProcessingDTOMixin:
    """File-item mixin that loads a raw caption and derives the caption used for training."""

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        """Initialise caption state; expects a ``dataset_config`` keyword argument."""
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.raw_caption: str = None
        self.raw_caption_short: str = None
        self.caption: str = None
        self.caption_short: str = None

        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        self.extra_values: List[float] = dataset_config.extra_values
        self.trigger_word = dataset_config.trigger_word

    # todo allow for loading from sd-scripts style dict
    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None] = None):
        """Populate ``raw_caption``/``raw_caption_short`` and the processed captions.

        Source priority: an already-loaded ``raw_caption``; then
        ``caption_dict[self.path]["caption"]``; then the sidecar file
        ``<path without extension> + dataset_config.caption_ext``; then
        ``dataset_config.default_caption`` (or ``''`` when that is None).

        Args:
            caption_dict: optional mapping ``path -> {"caption": ..., "caption_short": ...}``.
        """
        if self.raw_caption is not None:
            # we already loaded it
            pass
        elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]:
            entry = caption_dict[self.path]
            self.raw_caption = entry["caption"]
            if 'caption_short' in entry:
                self.raw_caption_short = entry["caption_short"]
                # only swap in the short caption when one is actually present
                if self.dataset_config.use_short_captions:
                    self.raw_caption = entry["caption_short"]
        else:
            # see if prompt file exists
            prompt_path = os.path.splitext(self.path)[0] + self.dataset_config.caption_ext
            default_caption = self.dataset_config.default_caption

            if os.path.exists(prompt_path):
                with open(prompt_path, 'r', encoding='utf-8') as f:
                    prompt = clean_caption(f.read())
                # an empty caption file falls back to the configured default
                if prompt.strip() == '' and default_caption is not None:
                    prompt = default_caption
            else:
                prompt = default_caption if default_caption is not None else ''

            self.raw_caption = prompt
            # sidecar files carry no short caption, so the default (possibly None) is used
            self.raw_caption_short = default_caption

        self.caption = self.get_caption()
        if self.raw_caption_short is not None:
            self.caption_short = self.get_caption(short_caption=True)

    def get_caption(
            self: 'FileItemDTO',
            trigger=None,
            to_replace_list=None,
            add_if_not_present=False,
            short_caption=False
    ):
        """Build a caption from the raw caption, applying the configured randomisation.

        Steps, in order: whole-caption dropout, per-token dropout, token shuffle,
        trigger injection via ``inject_trigger_into_prompt``, appending random
        triggers, and a second shuffle. Both dropout steps are skipped for short
        captions and when ``dataset_config.cache_text_embeddings`` is set.

        Args:
            trigger: trigger word; defaults to ``self.trigger_word`` when None.
            to_replace_list: forwarded unchanged to ``inject_trigger_into_prompt``.
            add_if_not_present: forwarded; forced to True when a trigger is set
                and the item is not a regularization item.
            short_caption: use ``raw_caption_short`` instead of ``raw_caption``.

        Returns:
            The processed caption string (``''`` when the caption was dropped).
        """
        cfg = self.dataset_config

        if trigger is None and self.trigger_word is not None:
            trigger = self.trigger_word

        if trigger is not None and not self.is_reg:
            # add if not present if not regularization
            add_if_not_present = True

        raw_caption = self.raw_caption_short if short_caption else self.raw_caption
        if raw_caption is None:
            raw_caption = ''

        dropout_allowed = not short_caption and not cfg.cache_text_embeddings

        # handle dropout
        if cfg.caption_dropout_rate > 0 and dropout_allowed:
            if random.random() < cfg.caption_dropout_rate:
                # drop the caption
                return ''

        # get tokens
        token_list = raw_caption.split(',')

        # handle token dropout
        if cfg.token_dropout_rate > 0 and dropout_allowed:
            kept_tokens = []
            keep_tokens: int = cfg.keep_tokens
            for idx, token in enumerate(token_list):
                if idx < keep_tokens:
                    # the leading keep_tokens tokens are never dropped
                    kept_tokens.append(token)
                elif cfg.token_dropout_rate >= 1.0:
                    # drop the token
                    pass
                elif random.random() > cfg.token_dropout_rate:
                    # keep the token
                    kept_tokens.append(token)
            token_list = kept_tokens

        if cfg.shuffle_tokens:
            random.shuffle(token_list)

        # join back together
        caption = ', '.join(token_list)
        caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)

        if cfg.random_triggers:
            num_triggers = cfg.random_triggers_max
            if num_triggers > 1:
                num_triggers = random.randint(0, num_triggers)
            # random.sample raises ValueError when asked for more items than exist
            num_triggers = min(num_triggers, len(cfg.random_triggers))

            if num_triggers > 0:
                triggers = random.sample(cfg.random_triggers, num_triggers)
                caption = caption + ', ' + ', '.join(triggers)

        if cfg.shuffle_tokens:
            # shuffle again
            token_list = caption.split(',')
            random.shuffle(token_list)
            caption = ', '.join(token_list)
        return caption
class ImageProcessingDTOMixin:
    """File-item mixin that loads an image or video from ``self.path`` into ``self.tensor``."""

    @staticmethod
    def _evenly_spaced_frame_indices(num_frames, max_frame_index):
        """Return ``num_frames`` indices spread evenly over ``0..max_frame_index`` (clamped)."""
        interval = max_frame_index / (num_frames - 1) if num_frames > 1 else 0
        return [min(int(round(i * interval)), max_frame_index) for i in range(num_frames)]

    def load_and_process_video(
            self: 'FileItemDTO',
            transform: "Union[None, transforms.Compose]",
            only_load_latents=False
    ):
        """Read frames with OpenCV, apply flip/scale/crop, and stack them into ``self.tensor``.

        When ``dataset_config.do_audio`` is set, the audio matching the selected
        frame range is also loaded with torchaudio and stored in
        ``self.audio_tensor`` / ``self.audio_data`` (both stay None when no usable
        segment is found).

        Args:
            transform: optional callable applied to every PIL frame. NOTE(review):
                ``torch.stack`` below needs tensors, so a tensor-producing
                transform appears to be required in practice; confirm with callers.
            only_load_latents: accepted for signature parity; not used here.

        Raises:
            Exception: for unsupported configuration (augments, augmentations,
                no buckets) and, wrapped as ``Video loading error (...)``, for any
                failure while reading frames or audio.
        """
        if self.augments is not None and len(self.augments) > 0:
            raise Exception('Augments not supported for videos')

        if self.has_augmentations:
            raise Exception('Augmentations not supported for videos')

        if not self.dataset_config.buckets:
            raise Exception('Buckets required for video processing')

        do_audio = self.dataset_config.do_audio
        # the config may not define `debug` at all
        debug = getattr(self.dataset_config, 'debug', False)

        # bound before the try so the handler and the finally block can always inspect it
        cap = None
        try:
            # Use OpenCV to capture video frames
            cap = cv2.VideoCapture(self.path)

            if not cap.isOpened():
                raise Exception(f"Failed to open video file: {self.path}")

            # Get video properties
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            video_fps = cap.get(cv2.CAP_PROP_FPS)

            # Calculate the max valid frame index (accounting for zero-indexing)
            max_frame_index = total_frames - 1

            # Only log video properties if in debug mode
            if debug:
                print_acc(f"Video properties: {self.path}")
                print_acc(f" Total frames: {total_frames}")
                print_acc(f" Max valid frame index: {max_frame_index}")
                print_acc(f" FPS: {video_fps}")

            frames_to_extract = []

            if self.dataset_config.auto_frame_count:
                # the duration cannot be derived from a missing/zero FPS; fail with a
                # clear message instead of a bare ZeroDivisionError
                if not video_fps or video_fps <= 0:
                    raise Exception(f"Invalid FPS ({video_fps}) reported for video, cannot auto-compute frame count")

                # allow for any length video here but make sure it is temporally compressable.
                vid_length_seconds = total_frames / video_fps

                desired_num_frames = int(vid_length_seconds * self.dataset_config.fps)

                # make sure it is divisible by temporal_compression
                desired_num_frames = desired_num_frames // self.temporal_compression * self.temporal_compression

                # TODO, all models currently add a key frame, but future models may not, update here if this changes.
                desired_num_frames += 1  # add one for the key frame that is always added

                self.num_frames = desired_num_frames

            # Always stretch/shrink to the requested number of frames if needed
            if self.dataset_config.shrink_video_to_frames or total_frames < self.num_frames:
                # Distribute frames evenly across the entire video
                frames_to_extract = self._evenly_spaced_frame_indices(self.num_frames, max_frame_index)
            else:
                # Calculate frame interval based on FPS ratio
                fps_ratio = video_fps / self.dataset_config.fps
                frame_interval = max(1, int(round(fps_ratio)))

                # Calculate max consecutive frames we can extract at desired FPS
                max_consecutive_frames = (total_frames // frame_interval)

                if max_consecutive_frames < self.num_frames:
                    # Not enough frames at desired FPS, so stretch instead
                    frames_to_extract = self._evenly_spaced_frame_indices(self.num_frames, max_frame_index)
                else:
                    # Calculate max start frame to ensure we can get all num_frames
                    max_start_frame = max_frame_index - ((self.num_frames - 1) * frame_interval)
                    start_frame = random.randint(0, max(0, max_start_frame))

                    # Generate list of frames to extract
                    frames_to_extract = [start_frame + (i * frame_interval) for i in range(self.num_frames)]

            # Final safety check - ensure no frame exceeds max valid index
            frames_to_extract = [min(frame_idx, max_frame_index) for frame_idx in frames_to_extract]

            # Only log frames to extract if in debug mode
            if debug:
                print_acc(f" Frames to extract: {frames_to_extract}")

            # Extract frames
            frames = []
            for frame_idx in frames_to_extract:
                # Safety check - ensure frame_idx is within bounds (silently fix)
                if frame_idx > max_frame_index:
                    frame_idx = max_frame_index

                # Set frame position
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)

                # Silently verify position was set correctly (no warnings unless debug mode)
                if debug:
                    actual_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                    if actual_pos != frame_idx:
                        print_acc(f"Warning: Failed to set exact frame position. Requested: {frame_idx}, Actual: {actual_pos}")

                ret, frame = cap.read()
                if not ret:
                    # Try to provide more detailed error information
                    actual_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                    frame_pos_info = f"Requested frame: {frame_idx}, Actual frame position: {actual_frame}"

                    # Try to read the next available frame as a fallback
                    for fallback_offset in [1, -1, 5, -5, 10, -10]:
                        fallback_pos = max(0, min(frame_idx + fallback_offset, max_frame_index))
                        cap.set(cv2.CAP_PROP_POS_FRAMES, fallback_pos)
                        fallback_ret, fallback_frame = cap.read()
                        if fallback_ret:
                            # Only log in debug mode
                            if debug:
                                print_acc(f"Falling back to nearby frame {fallback_pos} instead of {frame_idx}")
                            frame = fallback_frame
                            break
                    else:
                        # No fallback worked (loop ended without break), raise a more detailed exception
                        video_info = f"Video: {self.path}, Total frames: {total_frames}, FPS: {video_fps}"
                        raise Exception(f"Failed to read frame {frame_idx} from video. {frame_pos_info}. {video_info}")

                # Convert BGR to RGB
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Convert to PIL Image
                img = Image.fromarray(frame)

                # Apply the same processing as for single images
                img = img.convert('RGB')

                if self.flip_x:
                    img = img.transpose(Image.FLIP_LEFT_RIGHT)
                if self.flip_y:
                    img = img.transpose(Image.FLIP_TOP_BOTTOM)

                # Apply bucketing
                img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
                img = img.crop((
                    self.crop_x,
                    self.crop_y,
                    self.crop_x + self.crop_width,
                    self.crop_y + self.crop_height
                ))

                # Apply transform if provided
                if transform:
                    img = transform(img)

                frames.append(img)

            # Release the video capture as soon as frames are read (finally is the safety net)
            cap.release()

            # Stack frames into tensor (presumably [frames, channels, height, width])
            self.tensor = torch.stack(frames)

            # ------------------------------
            # Audio extraction + stretching
            # ------------------------------
            if do_audio:
                # Default to "no audio" unless we successfully extract it
                self.audio_data = None
                self.audio_tensor = None

                try:
                    import torchaudio
                    import torch.nn.functional as F

                    # Compute the time range of the selected frames in the *source* video
                    # Include the last frame by extending to the next frame boundary.
                    if video_fps and video_fps > 0 and len(frames_to_extract) > 0:
                        clip_start_frame = int(frames_to_extract[0])
                        clip_end_frame = int(frames_to_extract[-1])
                        clip_start_time = clip_start_frame / float(video_fps)
                        clip_end_time = (clip_end_frame + 1) / float(video_fps)
                        source_duration = max(0.0, clip_end_time - clip_start_time)
                    else:
                        clip_start_time = 0.0
                        clip_end_time = 0.0
                        source_duration = 0.0

                    # Target duration is how this sampled/stretched clip is interpreted for training
                    # (i.e. num_frames at the configured dataset FPS).
                    if hasattr(self.dataset_config, "fps") and self.dataset_config.fps and self.dataset_config.fps > 0:
                        target_duration = float(self.num_frames) / float(self.dataset_config.fps)
                    else:
                        target_duration = source_duration

                    waveform, sample_rate = torchaudio.load(self.path)  # [channels, samples]

                    waveform = waveform_to_stereo(waveform)  # Convert to stereo if not already

                    if self.dataset_config.audio_normalize:
                        peak = waveform.abs().amax()  # global peak across channels
                        eps = 1e-9
                        target_peak = 0.999  # ~ -0.01 dBFS
                        gain = target_peak / (peak + eps)
                        waveform = waveform * gain

                    # Slice to the selected clip region (when we have a meaningful time range)
                    if source_duration > 0.0:
                        start_sample = int(round(clip_start_time * sample_rate))
                        end_sample = int(round(clip_end_time * sample_rate))
                        start_sample = max(0, min(start_sample, waveform.shape[-1]))
                        end_sample = max(0, min(end_sample, waveform.shape[-1]))
                        if end_sample > start_sample:
                            waveform = waveform[..., start_sample:end_sample]
                        else:
                            # No valid audio segment
                            waveform = None
                    else:
                        # If we can't compute a meaningful time range, treat as no-audio
                        waveform = None

                    if waveform is not None and waveform.numel() > 0:
                        target_samples = int(round(target_duration * sample_rate))
                        if target_samples > 0 and waveform.shape[-1] != target_samples:
                            # Time-stretch/shrink to match the video clip duration implied by dataset FPS.
                            if self.dataset_config.audio_preserve_pitch:
                                waveform = time_stretch_preserve_pitch(waveform, sample_rate, target_samples)  # waveform is [C, L]
                            else:
                                # Use linear interpolation over the time axis.
                                wf = waveform.unsqueeze(0)  # [1, C, L]
                                wf = F.interpolate(wf, size=target_samples, mode="linear", align_corners=False)
                                waveform = wf.squeeze(0)  # [C, L]

                        self.audio_tensor = waveform
                        self.audio_data = {"waveform": waveform, "sample_rate": int(sample_rate)}

                except Exception as e:
                    # if issue with libtorchcodec "Could not load libtorchcodec"
                    raise Exception(f"** WARNING ** - Error Processing audio for {self.path}. Error: {e}") from e

            # Only log success in debug mode
            if debug:
                print_acc(f"Successfully loaded video with {len(frames)} frames: {self.path}")

        except Exception as e:
            # Print full traceback
            traceback.print_exc()

            # Provide more context about the error
            error_msg = str(e)
            try:
                if 'Failed to read frame' in error_msg and cap is not None:
                    # Try to get more info about the video that failed
                    cap_status = "Opened" if cap.isOpened() else "Closed"
                    current_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) if cap.isOpened() else "Unknown"
                    reported_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if cap.isOpened() else "Unknown"

                    print_acc(f"Video details when error occurred:")
                    print_acc(f" Cap status: {cap_status}")
                    print_acc(f" Current position: {current_pos}")
                    print_acc(f" Reported total frames: {reported_total}")

                    # Try to verify if the video is corrupted
                    if cap.isOpened():
                        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # Go to start
                        start_ret, _ = cap.read()

                        # Try to read the last frame to check if it's accessible
                        if reported_total > 0:
                            cap.set(cv2.CAP_PROP_POS_FRAMES, reported_total - 1)
                            end_ret, _ = cap.read()
                            print_acc(f" Can read first frame: {start_ret}, Can read last frame: {end_ret}")
            except Exception as debug_err:
                print_acc(f"Error during error diagnosis: {debug_err}")

            print_acc(f"Error: {error_msg}")
            print_acc(f"Error loading video: {self.path}")

            # Re-raise with more detailed information
            raise Exception(f"Video loading error ({self.path}): {error_msg}") from e
        finally:
            # always give the capture handle back, whichever path got us here
            if cap is not None:
                cap.release()

    def load_and_process_image(
            self: 'FileItemDTO',
            transform: "Union[None, transforms.Compose]",
            only_load_latents=False
    ):
        """Load ``self.path`` as an RGB image, scale/crop/augment it and store it in ``self.tensor``.

        Dispatch, in order:
          * cached text embedding -> ``load_prompt_embedding()``;
          * cached latent -> ``get_latent()``; unless
            ``dataset_config.load_image_when_caching_latents`` is set, the
            auxiliary images are loaded and the method returns;
          * audio model -> ``load_and_process_audio()``;
          * ``num_frames > 1`` or ``auto_frame_count`` -> ``load_and_process_video()``;
          * otherwise the image path below.

        Args:
            transform: optional callable applied to the final PIL image.
            only_load_latents: when True, auxiliary images (control, inpaint,
                clip, mask, unconditional) are not loaded after the main image.

        Raises:
            Whatever ``Image.open`` / ``exif_transpose`` raise when the file cannot
            be read (logged first).
        """
        # handle get_prompt_embedding
        if self.is_text_embedding_cached:
            self.load_prompt_embedding()
        # if we are caching latents, just do that
        if self.is_latent_cached:
            self.get_latent()
            # if load_image_when_caching_latents is set, we still need the raw image
            # tensor in addition to the cached latent, so fall through to load it below
            if not self.dataset_config.load_image_when_caching_latents:
                if self.has_control_image:
                    self.load_control_image()
                if self.has_inpaint_image:
                    self.load_inpaint_image()
                if self.has_clip_image:
                    self.load_clip_image()
                if self.has_mask_image:
                    self.load_mask_image()
                if self.has_unconditional:
                    self.load_unconditional_image()
                return
        if self.is_audio_model:
            self.load_and_process_audio()
            return
        if self.dataset_config.num_frames > 1 or self.dataset_config.auto_frame_count:
            self.load_and_process_video(transform, only_load_latents)
            return
        try:
            img = Image.open(self.path)
            img = exif_transpose(img)
        except Exception as e:
            print_acc(f"Error: {e}")
            print_acc(f"Error loading image: {self.path}")
            # nothing below can work without an image; surface the real error rather
            # than an UnboundLocalError on `img`
            raise

        if self.use_alpha_as_mask:
            # we do this to make sure it does not replace the alpha with another color
            # we want the image just without the alpha channel
            np_img = np.array(img)
            # strip off alpha; arrays without a channel axis have no alpha to strip
            if np_img.ndim == 3:
                np_img = np_img[:, :, :3]
                img = Image.fromarray(np_img)

        img = img.convert('RGB')
        w, h = img.size
        # orientation of the file and of the precomputed scale targets should agree
        landscape_mismatch = w > h and self.scale_to_width < self.scale_to_height
        portrait_mismatch = h > w and self.scale_to_height < self.scale_to_width
        if landscape_mismatch or portrait_mismatch:
            print_acc(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

        if self.flip_x:
            # do a flip
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if self.flip_y:
            # do a flip
            img = img.transpose(Image.FLIP_TOP_BOTTOM)

        if self.dataset_config.buckets:
            # scale and crop based on file item
            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
            # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
            if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
                # todo look into this. This still happens sometimes
                print_acc('size mismatch')
            img = img.crop((
                self.crop_x,
                self.crop_y,
                self.crop_x + self.crop_width,
                self.crop_y + self.crop_height
            ))
        else:
            resolution = self.dataset_config.resolution
            # Downscale the source image first
            # TODO this is not right
            img = img.resize(
                (int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)),
                Image.BICUBIC)
            min_img_size = min(img.size)
            if self.dataset_config.random_crop:
                if self.dataset_config.random_scale and min_img_size > resolution:
                    # shortest side is strictly larger than the target here, so the
                    # random range is always valid
                    scale_size = random.randint(resolution, int(min_img_size))
                    scaler = scale_size / min_img_size
                    scale_width = int((img.width + 5) * scaler)
                    scale_height = int((img.height + 5) * scaler)
                    img = img.resize((scale_width, scale_height), Image.BICUBIC)
                img = transforms.RandomCrop(resolution)(img)
            else:
                img = transforms.CenterCrop(min_img_size)(img)
                img = img.resize((resolution, resolution), Image.BICUBIC)

        if self.augments is not None and len(self.augments) > 0:
            # do augmentations; names missing from transforms_dict are skipped
            for augment in self.augments:
                if augment in transforms_dict:
                    img = transforms_dict[augment](img)

        if self.has_augmentations:
            # augmentations handles transforms
            img = self.augment_image(img, transform=transform)
        elif transform:
            img = transform(img)

        self.tensor = img
        if not only_load_latents:
            if self.has_control_image:
                self.load_control_image()
            if self.has_inpaint_image:
                self.load_inpaint_image()
            if self.has_clip_image:
                self.load_clip_image()
            if self.has_mask_image:
                self.load_mask_image()
            if self.has_unconditional:
                self.load_unconditional_image()
class ControlFileItemDTOMixin:
    """File-item mixin that locates and loads the control image(s) paired with an item."""

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        """Find control images named like the item's file in ``dataset_config.control_path``.

        Reads the ``dataset_config``, ``path`` and optional ``sd`` keyword
        arguments. ``control_path`` may be one folder or a list of folders; in
        each folder the first existing ``<file name> + ext`` (``ext`` from
        ``img_ext_list``) is taken. ``self.control_path`` ends up as None (nothing
        found), a single path (one found) or a list of paths.
        """
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_control_image = False
        self.control_path: Union[str, List[str], None] = None
        self.control_tensor: Union[torch.Tensor, None] = None
        self.control_tensor_list: Union[List[torch.Tensor], None] = None
        sd = kwargs.get('sd', None)
        self.use_raw_control_images = sd is not None and sd.use_raw_control_images
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        self.full_size_control_images = False
        if dataset_config.control_path is not None:
            # find the control image path
            control_path_list = dataset_config.control_path
            if not isinstance(control_path_list, list):
                control_path_list = [control_path_list]
            self.full_size_control_images = dataset_config.full_size_control_images
            # we are using control images
            img_path = kwargs.get('path', None)
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]

            found_control_images = []
            for control_path in control_path_list:
                for ext in img_ext_list:
                    candidate = os.path.join(control_path, file_name_no_ext + ext)
                    if os.path.exists(candidate):
                        found_control_images.append(candidate)
                        self.has_control_image = True
                        # at most one control image per folder
                        break
            self.control_path = found_control_images
            if len(self.control_path) == 0:
                self.control_path = None
            elif len(self.control_path) == 1:
                # only do one
                self.control_path = self.control_path[0]

    def load_control_image(self: 'FileItemDTO'):
        """Load every control image into a tensor and store the result on the item.

        Images with an alpha channel are flattened onto
        ``dataset_config.control_transparent_color``. Then, per image:
          * ``full_size_control_images`` false -> resized to 512x512;
          * else, unless ``use_raw_control_images`` -> flipped and bucket
            scaled/cropped like the main image;
          * else left at its own size.
        One tensor goes to ``self.control_tensor``; several are stacked there, or
        kept as a list in ``self.control_tensor_list`` for raw control images.

        Raises:
            Exception: for non-bucket datasets on the bucket path; any error from
                opening/decoding a control image is logged and re-raised.
        """
        control_tensors = []
        control_path_list = self.control_path
        if not isinstance(self.control_path, list):
            control_path_list = [self.control_path]

        to_tensor = transforms.Compose([
            transforms.ToTensor(),
        ])

        for control_path in control_path_list:
            try:
                img = Image.open(control_path)
                img = exif_transpose(img)

                if img.mode in ("RGBA", "LA"):
                    # Create a background with the specified transparent color
                    transparent_color = tuple(self.dataset_config.control_transparent_color)
                    background = Image.new("RGB", img.size, transparent_color)
                    # Paste the image on top using its alpha channel as mask
                    background.paste(img, mask=img.getchannel("A"))
                    img = background
                else:
                    # Already no alpha channel
                    img = img.convert("RGB")
            except Exception as e:
                print_acc(f"Error: {e}")
                print_acc(f"Error loading image: {control_path}")
                # continuing would either hit an unbound `img` or silently reuse the
                # previous control image for this slot
                raise

            if not self.full_size_control_images:
                # we just scale them to 512x512:
                img = img.resize((512, 512), Image.BICUBIC)

            elif not self.use_raw_control_images:
                if self.flip_x:
                    # do a flip
                    img = img.transpose(Image.FLIP_LEFT_RIGHT)
                if self.flip_y:
                    # do a flip
                    img = img.transpose(Image.FLIP_TOP_BOTTOM)

                if self.dataset_config.buckets:
                    # scale and crop based on file item
                    img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
                    # crop
                    img = img.crop((
                        self.crop_x,
                        self.crop_y,
                        self.crop_x + self.crop_width,
                        self.crop_y + self.crop_height
                    ))
                else:
                    raise Exception("Control images not supported for non-bucket datasets")

            if self.aug_replay_spatial_transforms:
                tensor = self.augment_spatial_control(img, transform=to_tensor)
            else:
                tensor = to_tensor(img)
            control_tensors.append(tensor)

        if len(control_tensors) == 0:
            self.control_tensor = None
        elif len(control_tensors) == 1:
            self.control_tensor = control_tensors[0]
        elif self.use_raw_control_images:
            # just send the list of tensors as their shapes wont match
            self.control_tensor_list = control_tensors
        else:
            self.control_tensor = torch.stack(control_tensors, dim=0)

    def cleanup_control(self: 'FileItemDTO'):
        """Drop the loaded control tensors."""
        self.control_tensor = None
        self.control_tensor_list = None
self.is_vision_clip_cached = False + self.clip_vision_is_quad = False + self.clip_vision_load_device = 'cpu' + self.clip_vision_unconditional_paths: Union[List[str], None] = None + self._clip_vision_embeddings_path: Union[str, None] = None + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + if dataset_config.clip_image_path is not None or dataset_config.clip_image_from_same_folder: + # copy the clip image processor so the dataloader can do it + sd = kwargs.get('sd', None) + if hasattr(sd.adapter, 'clip_image_processor'): + self.clip_image_processor = sd.adapter.clip_image_processor + if dataset_config.clip_image_path is not None: + # find the control image path + clip_image_path = dataset_config.clip_image_path + # we are using control images + img_path = kwargs.get('path', None) + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + for ext in img_ext_list: + if os.path.exists(os.path.join(clip_image_path, file_name_no_ext + ext)): + self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext) + self.has_clip_image = True + break + self.build_clip_imag_augmentation_transform() + + if dataset_config.clip_image_from_same_folder: + # assume we have one. We will pull it on load. 
+ self.has_clip_image = True + self.build_clip_imag_augmentation_transform() + + def build_clip_imag_augmentation_transform(self: 'FileItemDTO'): + if self.dataset_config.clip_image_augmentations is not None and len(self.dataset_config.clip_image_augmentations) > 0: + self.has_clip_augmentations = True + augmentations = [Augments(**aug) for aug in self.dataset_config.clip_image_augmentations] + + if self.dataset_config.clip_image_shuffle_augmentations: + random.shuffle(augmentations) + + augmentation_list = [] + for aug in augmentations: + # make sure method name is valid + assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}" + # get the method + method = getattr(A, aug.method_name) + # add the method to the list + augmentation_list.append(method(**aug.params)) + + self.clip_image_aug_transform = A.Compose(augmentation_list) + + def augment_clip_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ): + if self.dataset_config.clip_image_shuffle_augmentations: + self.build_clip_imag_augmentation_transform() + + open_cv_image = np.array(img) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + if self.clip_vision_is_quad: + # image is in a 2x2 gris. 
split, run augs, and recombine + # split + img1, img2 = np.hsplit(open_cv_image, 2) + img1_1, img1_2 = np.vsplit(img1, 2) + img2_1, img2_2 = np.vsplit(img2, 2) + # apply augmentations + img1_1 = self.clip_image_aug_transform(image=img1_1)["image"] + img1_2 = self.clip_image_aug_transform(image=img1_2)["image"] + img2_1 = self.clip_image_aug_transform(image=img2_1)["image"] + img2_2 = self.clip_image_aug_transform(image=img2_2)["image"] + # recombine + augmented = np.vstack((np.hstack((img1_1, img1_2)), np.hstack((img2_1, img2_2)))) + + else: + # apply augmentations + augmented = self.clip_image_aug_transform(image=open_cv_image)["image"] + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented) + + return augmented_tensor + + def get_clip_vision_info_dict(self: 'FileItemDTO'): + item = OrderedDict([ + ("image_encoder_path", self.clip_image_encoder_path), + ("filename", os.path.basename(self.clip_image_path)), + ("is_quad", self.clip_vision_is_quad) + ]) + # when adding items, do it after so we dont change old latents + if self.flip_x: + item["flip_x"] = True + if self.flip_y: + item["flip_y"] = True + return item + def get_clip_vision_embeddings_path(self: 'FileItemDTO', recalculate=False): + if self._clip_vision_embeddings_path is not None and not recalculate: + return self._clip_vision_embeddings_path + else: + # we store latents in a folder in same path as image called _latent_cache + img_dir = os.path.dirname(self.clip_image_path) + latent_dir = os.path.join(img_dir, '_clip_vision_cache') + hash_dict = self.get_clip_vision_info_dict() + filename_no_ext = os.path.splitext(os.path.basename(self.clip_image_path))[0] + # get base64 hash of md5 checksum of hash_dict + hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8') + hash_str = 
base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii') + hash_str = hash_str.replace('=', '') + self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors') + + return self._clip_vision_embeddings_path + + def get_new_clip_image_path(self: 'FileItemDTO'): + if self.dataset_config.clip_image_from_same_folder: + # randomly grab an image path from the same folder + pool_folder = os.path.dirname(self.path) + # find all images in the folder + img_files = [] + for ext in img_ext_list: + img_files += glob.glob(os.path.join(pool_folder, f'*{ext}')) + # remove the current image if len is greater than 1 + if len(img_files) > 1: + img_files.remove(self.path) + # randomly grab one + return random.choice(img_files) + else: + return self.clip_image_path + + def load_clip_image(self: 'FileItemDTO'): + is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible) or \ + isinstance(self.clip_image_processor, SiglipImageProcessor) + if self.clip_image_processor is None: + is_dynamic_size_and_aspect = True # serving it raw + if self.is_vision_clip_cached: + self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path()) + + # get a random unconditional image + if self.clip_vision_unconditional_paths is not None: + unconditional_path = random.choice(self.clip_vision_unconditional_paths) + self.clip_image_embeds_unconditional = load_file(unconditional_path) + + return + clip_image_path = self.get_new_clip_image_path() + try: + img = Image.open(clip_image_path).convert('RGB') + img = exif_transpose(img) + except Exception as e: + # make a random noise image + img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution)) + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {clip_image_path}") + + img = img.convert('RGB') + + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + 
img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if is_dynamic_size_and_aspect: + pass # let the image processor handle it + elif img.width != img.height: + min_size = min(img.width, img.height) + if self.dataset_config.square_crop: + # center crop to a square + img = transforms.CenterCrop(min_size)(img) + else: + # image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data + # resize to the smallest dimension + img = img.resize((min_size, min_size), Image.BICUBIC) + + if self.has_clip_augmentations: + self.clip_image_tensor = self.augment_clip_image(img, transform=None) + else: + self.clip_image_tensor = transforms.ToTensor()(img) + + # random crop + # if self.dataset_config.clip_image_random_crop: + # # crop up to 20% on all sides. Keep is square + # crop_percent = random.randint(0, 20) / 100 + # crop_width = int(self.clip_image_tensor.shape[2] * crop_percent) + # crop_height = int(self.clip_image_tensor.shape[1] * crop_percent) + # crop_left = random.randint(0, crop_width) + # crop_top = random.randint(0, crop_height) + # crop_right = self.clip_image_tensor.shape[2] - crop_width - crop_left + # crop_bottom = self.clip_image_tensor.shape[1] - crop_height - crop_top + # if len(self.clip_image_tensor.shape) == 3: + # self.clip_image_tensor = self.clip_image_tensor[:, crop_top:-crop_bottom, crop_left:-crop_right] + # elif len(self.clip_image_tensor.shape) == 4: + # self.clip_image_tensor = self.clip_image_tensor[:, :, crop_top:-crop_bottom, crop_left:-crop_right] + + if self.clip_image_processor is not None: + # run it + tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16) + clip_out = self.clip_image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + self.clip_image_tensor = clip_out.squeeze(0).clone().detach() + + def cleanup_clip_image(self: 'FileItemDTO'): + self.clip_image_tensor = None + self.clip_image_embeds = None + + + + +class 
AugmentationFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_augmentations = False + self.unaugmented_tensor: Union[torch.Tensor, None] = None + # self.augmentations: Union[None, List[Augments]] = None + self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + self.aug_transform: Union[None, A.Compose] = None + self.aug_replay_spatial_transforms = None + self.build_augmentation_transform() + + def build_augmentation_transform(self: 'FileItemDTO'): + if self.dataset_config.augmentations is not None and len(self.dataset_config.augmentations) > 0: + self.has_augmentations = True + augmentations = [Augments(**aug) for aug in self.dataset_config.augmentations] + + if self.dataset_config.shuffle_augmentations: + random.shuffle(augmentations) + + augmentation_list = [] + for aug in augmentations: + # make sure method name is valid + assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}" + # get the method + method = getattr(A, aug.method_name) + # add the method to the list + augmentation_list.append(method(**aug.params)) + + # add additional targets so we can augment the control image + self.aug_transform = A.ReplayCompose(augmentation_list, additional_targets={'image2': 'image'}) + + def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ): + + # rebuild each time if shuffle + if self.dataset_config.shuffle_augmentations: + self.build_augmentation_transform() + + # save the original tensor + self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img) + + open_cv_image = np.array(img) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + # apply augmentations + transformed = self.aug_transform(image=open_cv_image) + augmented = transformed["image"] + + # save just the spatial transforms for controls and masks + 
augmented_params = transformed["replay"] + spatial_transforms = ['Rotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Resize', 'Crop', 'RandomCrop', + 'ElasticTransform', 'GridDistortion', 'OpticalDistortion'] + # only store the spatial transforms + augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms] + + if self.dataset_config.replay_transforms: + self.aug_replay_spatial_transforms = augmented_params + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented) + + return augmented_tensor + + # augment control images spatially consistent with transforms done to the main image + def augment_spatial_control(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose] ): + if self.aug_replay_spatial_transforms is None: + # no transforms + return transform(img) + + # save colorspace to convert back to + colorspace = img.mode + + # convert to rgb + img = img.convert('RGB') + + open_cv_image = np.array(img) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + # Replay transforms + transformed = A.ReplayCompose.replay(self.aug_replay_spatial_transforms, image=open_cv_image) + augmented = transformed["image"] + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + # convert back to original colorspace + augmented = augmented.convert(colorspace) + + augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented) + return augmented_tensor + + def cleanup_control(self: 'FileItemDTO'): + self.unaugmented_tensor = None + + +class MaskFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if 
hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_mask_image = False + self.mask_path: Union[str, None] = None + self.mask_tensor: Union[torch.Tensor, None] = None + self.use_alpha_as_mask: bool = False + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + self.mask_min_value = dataset_config.mask_min_value + if dataset_config.alpha_mask: + self.use_alpha_as_mask = True + self.mask_path = kwargs.get('path', None) + self.has_mask_image = True + elif dataset_config.mask_path is not None: + # find the control image path + mask_path = dataset_config.mask_path if dataset_config.mask_path is not None else dataset_config.alpha_mask + # we are using control images + img_path = kwargs.get('path', None) + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + for ext in img_ext_list: + if os.path.exists(os.path.join(mask_path, file_name_no_ext + ext)): + self.mask_path = os.path.join(mask_path, file_name_no_ext + ext) + self.has_mask_image = True + break + + def load_mask_image(self: 'FileItemDTO'): + try: + img = Image.open(self.mask_path) + img = exif_transpose(img) + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.mask_path}") + + if self.use_alpha_as_mask: + # pipeline expectws an rgb image so we need to put alpha in all channels + np_img = np.array(img) + np_img[:, :, :3] = np_img[:, :, 3:] + + np_img = np_img[:, :, :3] + img = Image.fromarray(np_img) + + img = img.convert('RGB') + if self.dataset_config.invert_mask: + img = ImageOps.invert(img) + w, h = img.size + fix_size = False + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + fix_size = True + elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + 
print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + fix_size = True + + if fix_size: + # swap all the sizes + self.scale_to_width, self.scale_to_height = self.scale_to_height, self.scale_to_width + self.crop_width, self.crop_height = self.crop_height, self.crop_width + self.crop_x, self.crop_y = self.crop_y, self.crop_x + + + + + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + # randomly apply a blur up to 0.5% of the size of the min (width, height) + min_size = min(img.width, img.height) + blur_radius = int(min_size * random.random() * 0.005) + img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) + + # make grayscale + img = img.convert('L') + + if self.dataset_config.buckets: + # scale and crop based on file item + img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) + # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + # crop + img = img.crop(( + self.crop_x, + self.crop_y, + self.crop_x + self.crop_width, + self.crop_y + self.crop_height + )) + else: + raise Exception("Mask images not supported for non-bucket datasets") + + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + if self.aug_replay_spatial_transforms: + self.mask_tensor = self.augment_spatial_control(img, transform=transform) + else: + self.mask_tensor = transform(img) + self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0) + # convert to grayscale + + def cleanup_mask(self: 'FileItemDTO'): + self.mask_tensor = None + + +class UnconditionalFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_unconditional = False + self.unconditional_path: Union[str, None] = None + 
self.unconditional_tensor: Union[torch.Tensor, None] = None + self.unconditional_latent: Union[torch.Tensor, None] = None + self.unconditional_transforms = self.dataloader_transforms + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + + if dataset_config.unconditional_path is not None: + # we are using control images + img_path = kwargs.get('path', None) + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + for ext in img_ext_list: + if os.path.exists(os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)): + self.unconditional_path = os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext) + self.has_unconditional = True + break + + def load_unconditional_image(self: 'FileItemDTO'): + try: + img = Image.open(self.unconditional_path) + img = exif_transpose(img) + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.mask_path}") + + img = img.convert('RGB') + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if self.dataset_config.buckets: + # scale and crop based on file item + img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) + # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + # crop + img = img.crop(( + self.crop_x, + self.crop_y, + 
self.crop_x + self.crop_width, + self.crop_y + self.crop_height + )) + else: + raise Exception("Unconditional images are not supported for non-bucket datasets") + + if self.aug_replay_spatial_transforms: + self.unconditional_tensor = self.augment_spatial_control(img, transform=self.unconditional_transforms) + else: + self.unconditional_tensor = self.unconditional_transforms(img) + + def cleanup_unconditional(self: 'FileItemDTO'): + self.unconditional_tensor = None + self.unconditional_latent = None + +class ArgBreakMixin: + # just stops super calls form hitting object + def __init__(self, *args, **kwargs): + pass + + +class LatentCachingFileItemDTOMixin: + def __init__(self, *args, **kwargs): + # if we have super, call it + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self._encoded_latent: Union[torch.Tensor, None] = None + self._cached_first_frame_latent: Union[torch.Tensor, None] = None + self._cached_audio_latent: Union[torch.Tensor, None] = None + self._latent_path: Union[str, None] = None + self.is_latent_cached = False + self.is_caching_to_disk = False + self.is_caching_to_memory = False + self.latent_load_device = 'cpu' + # todo, increment this if we change the latent format to invalidate cache + self.latent_version = 1 + + def get_latent_info_dict(self: 'FileItemDTO'): + item = OrderedDict([ + ("filename", os.path.basename(self.path)), + ("scale_to_width", self.scale_to_width), + ("scale_to_height", self.scale_to_height), + ("crop_x", self.crop_x), + ("crop_y", self.crop_y), + ("crop_width", self.crop_width), + ("crop_height", self.crop_height), + ("latent_space_version", self.latent_space_version), + ("latent_version", self.latent_version), + ]) + is_video = False + # when adding items, do it after so we dont change old latents + if self.flip_x: + item["flip_x"] = True + if self.flip_y: + item["flip_y"] = True + if self.dataset_config.auto_frame_count: + # don't store num frames here as it is calculated dynamically + 
item["auto_frame_count"] = True + is_video = True + elif self.dataset_config.num_frames > 1: + item["num_frames"] = self.dataset_config.num_frames + is_video = True + if is_video and self.dataset_config.fps != 24: + # only add fps if it deviates from the default + item["fps"] = self.dataset_config.fps + if is_video and self.dataset_config.do_i2v: + item["do_i2v"] = True + if is_video and self.dataset_config.do_audio: + item["do_audio"] = True + if self.dataset_config.audio_normalize: + item["audio_normalize"] = True + if self.dataset_config.audio_preserve_pitch: + item["audio_preserve_pitch"] = True + if self.is_audio_model: + item["is_audio_model"] = True + item["sample_rate"] = self.sample_rate + return item + + def get_latent_path(self: 'FileItemDTO', recalculate=False): + if self._latent_path is not None and not recalculate: + return self._latent_path + else: + # we store latents in a folder in same path as image called _latent_cache + img_dir = os.path.dirname(self.path) + latent_dir = os.path.join(img_dir, '_latent_cache') + hash_dict = self.get_latent_info_dict() + filename_no_ext = os.path.splitext(os.path.basename(self.path))[0] + # get base64 hash of md5 checksum of hash_dict + hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8') + hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii') + hash_str = hash_str.replace('=', '') + self._latent_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors') + + return self._latent_path + + def cleanup_latent(self): + if self._encoded_latent is not None: + if not self.is_caching_to_memory: + # we are caching on disk, don't save in memory + self._encoded_latent = None + self._cached_first_frame_latent = None + self._cached_audio_latent = None + else: + # move it back to cpu + self._encoded_latent = self._encoded_latent.to('cpu') + if self._cached_first_frame_latent is not None: + self._cached_first_frame_latent = self._cached_first_frame_latent.to('cpu') + 
if self._cached_audio_latent is not None: + self._cached_audio_latent = self._cached_audio_latent.to('cpu') + + def get_latent(self, device=None): + if not self.is_latent_cached: + return None + if self._encoded_latent is None: + # load it from disk + state_dict = load_file( + self.get_latent_path(), + # device=device if device is not None else self.latent_load_device + device='cpu' + ) + self._encoded_latent = state_dict['latent'] + if 'first_frame_latent' in state_dict: + self._cached_first_frame_latent = state_dict['first_frame_latent'] + if 'audio_latent' in state_dict: + self._cached_audio_latent = state_dict['audio_latent'] + if 'num_frames' in state_dict: + self.num_frames = int(state_dict['num_frames'].item()) + return self._encoded_latent + + +class LatentCachingMixin: + def __init__(self: 'AiToolkitDataset', **kwargs): + # if we have super, call it + if hasattr(super(), '__init__'): + super().__init__(**kwargs) + self.latent_cache = {} + + def cache_latents_all_latents(self: 'AiToolkitDataset'): + with accelerator.main_process_first(): + print_acc(f"Caching latents for {self.dataset_path}") + # cache all latents to disk + to_disk = self.is_caching_latents_to_disk + to_memory = self.is_caching_latents_to_memory + + if to_disk: + print_acc(" - Saving latents to disk") + if to_memory: + print_acc(" - Keeping latents in memory") + # move sd items to cpu except for vae + self.sd.set_device_state_preset('cache_latents') + + # use tqdm to show progress + i = 0 + for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): + file_item.is_caching_to_disk = to_disk + file_item.is_caching_to_memory = to_memory + file_item.latent_load_device = self.sd.device + + latent_path = file_item.get_latent_path(recalculate=True) + # check if it is saved to disk already + if os.path.exists(latent_path): + if to_memory: + # load it into memory + state_dict = load_file(latent_path, device='cpu') + file_item._encoded_latent = 
state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) + if 'first_frame_latent' in state_dict: + file_item._cached_first_frame_latent = state_dict['first_frame_latent'].to('cpu', dtype=self.sd.torch_dtype) + if 'audio_latent' in state_dict: + file_item._cached_audio_latent = state_dict['audio_latent'].to('cpu', dtype=self.sd.torch_dtype) + else: + # not saved to disk, calculate + # load the image first + file_item.load_and_process_image(self.transform, only_load_latents=True) + dtype = self.sd.torch_dtype + device = self.sd.device_torch + state_dict = OrderedDict() + first_frame_latent = None + audio_latent = None + frames = None + # add batch dimension + try: + imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) + latent = self.sd.encode_images(imgs).squeeze(0) + if to_disk: + state_dict['latent'] = latent.clone().detach().cpu() + except Exception as e: + print_acc(f"Error processing image: {file_item.path}") + print_acc(f"Error: {str(e)}") + raise e + # do first frame + is_video = self.dataset_config.auto_frame_count or self.dataset_config.num_frames > 1 + if is_video and self.dataset_config.do_i2v: + frames = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) + if len(frames.shape) == 4: + first_frames = frames + elif len(frames.shape) == 5: + first_frames = frames[:, 0] + else: + raise ValueError(f"Unknown frame shape {frames.shape}") + first_frame_latent = self.sd.encode_images(first_frames).squeeze(0) + if to_disk: + state_dict['first_frame_latent'] = first_frame_latent.clone().detach().cpu() + + # audio (video+audio models only — audio-only models already encoded above via encode_images) + if not self.is_audio_model and file_item.audio_data is not None: + audio_latent = self.sd.encode_audio([file_item.audio_data]).squeeze(0) + if to_disk: + state_dict['audio_latent'] = audio_latent.clone().detach().cpu() + + if is_video: + state_dict['num_frames'] = torch.tensor(file_item.num_frames, dtype=torch.int32) + + # save_latent + if to_disk: + # 
metadata + meta = get_meta_for_safetensors(file_item.get_latent_info_dict()) + os.makedirs(os.path.dirname(latent_path), exist_ok=True) + save_file(state_dict, latent_path, metadata=meta) + + if to_memory: + # keep it in memory + file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) + if first_frame_latent is not None: + file_item._cached_first_frame_latent = first_frame_latent.to('cpu', dtype=self.sd.torch_dtype) + if audio_latent is not None: + file_item._cached_audio_latent = audio_latent.to('cpu', dtype=self.sd.torch_dtype) + + del imgs + del latent + del frames + del file_item.tensor + del state_dict + del first_frame_latent + del audio_latent + file_item.cleanup() + + file_item.is_latent_cached = True + i += 1 + + # restore device state + self.sd.restore_device_state() + + +class TextEmbeddingFileItemDTOMixin: + def __init__(self, *args, **kwargs): + # if we have super, call it + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.prompt_embeds: Union[PromptEmbeds, None] = None + self._text_embedding_path: Union[str, None] = None + self.is_text_embedding_cached = False + self.text_embedding_load_device = 'cpu' + self.text_embedding_version = 1 + + def get_text_embedding_info_dict(self: 'FileItemDTO'): + # make sure the caption is loaded here + # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible. 
+ if self.caption is None: + self.load_caption() + item = OrderedDict([ + ("caption", self.caption), + ("text_embedding_space_version", self.text_embedding_space_version), + ("text_embedding_version", self.text_embedding_version), + ]) + # if we have a control image, cache the path + if self.encode_control_in_text_embeddings and self.control_path is not None: + item["control_path"] = self.control_path + return item + + def get_text_embedding_path(self: 'FileItemDTO', recalculate=False): + if self._text_embedding_path is not None and not recalculate: + return self._text_embedding_path + else: + # we store text embeddings in a folder in same path as image called _text_embedding_cache + img_dir = os.path.dirname(self.path) + te_dir = os.path.join(img_dir, '_t_e_cache') + hash_dict = self.get_text_embedding_info_dict() + filename_no_ext = os.path.splitext(os.path.basename(self.path))[0] + # get base64 hash of md5 checksum of hash_dict + hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8') + hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii') + hash_str = hash_str.replace('=', '') + self._text_embedding_path = os.path.join(te_dir, f'{filename_no_ext}_{hash_str}.safetensors') + + return self._text_embedding_path + + def cleanup_text_embedding(self): + if self.prompt_embeds is not None: + # we are caching on disk, don't save in memory + self.prompt_embeds = None + + def load_prompt_embedding(self, device=None): + if not self.is_text_embedding_cached: + return + if self.prompt_embeds is None: + # load it from disk + self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path()) + +class TextEmbeddingCachingMixin: + def __init__(self: 'AiToolkitDataset', **kwargs): + # if we have super, call it + if hasattr(super(), '__init__'): + super().__init__(**kwargs) + self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings + + def cache_text_embeddings(self: 'AiToolkitDataset'): + with 
class CLIPCachingMixin:
    """Dataset mixin that encodes CLIP-vision embeddings once and caches them on disk.

    The host dataset is expected to provide ``sd`` (with an ``adapter``),
    ``file_list``, ``dataset_config``, ``dataset_path`` and the
    ``is_caching_clip_vision_to_disk`` flag; all of them are read by
    :meth:`cache_clip_vision_to_disk`.
    """

    def __init__(self: 'AiToolkitDataset', **kwargs):
        # cooperative init: forward kwargs to the next class in the MRO, if any
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        # number of unconditional embeddings kept when the adapter uses random
        # noise as its unconditional input; collapsed to 1 otherwise
        self.clip_vision_num_unconditional_cache = 20
        # paths of the cached unconditional embeddings, filled by the caching pass
        self.clip_vision_unconditional_cache = []

    @staticmethod
    def _clip_quad_split(clip_image):
        """Split a 2x2 image grid into four tiles and stack them on the batch axis.

        The tensor is halved along dim 2 and then each half along dim 3.
        NOTE(review): presumably dim 2 is height and dim 3 is width — confirm.
        """
        first_half, second_half = clip_image.chunk(2, dim=2)
        tile_a, tile_c = first_half.chunk(2, dim=3)
        tile_b, tile_d = second_half.chunk(2, dim=3)
        # keep the exact stacking order the cache files were written with
        return torch.cat([tile_a, tile_b, tile_c, tile_d], dim=0).detach()

    @staticmethod
    def _clip_output_to_state_dict(clip_output):
        """Collect the tensors that get persisted for one vision-encoder output (moved to CPU)."""
        return OrderedDict([
            ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
            ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
            ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
        ])

    def cache_clip_vision_to_disk(self: 'AiToolkitDataset'):
        """Encode every file item with the adapter's vision encoder and save the result.

        Does nothing unless ``self.is_caching_clip_vision_to_disk`` is truthy.
        First the unconditional embeddings are written (random noise when the
        adapter sets ``clip_noise_zero``, zeros otherwise), then one
        safetensors file per entry of ``self.file_list`` that is not on disk yet.

        Raises:
            Exception: if the adapter, its CLIP image processor or its vision
                encoder is missing, or if a file item uses clip augmentations.

        The device state switched by ``set_device_state_preset('cache_clip')`` is
        restored even when one of the steps above raises.
        """
        if not self.is_caching_clip_vision_to_disk:
            return
        with torch.no_grad():
            print_acc(f"Caching clip vision for {self.dataset_path}")

            print_acc(" - Saving clip to disk")
            # move sd items to cpu except for vae
            self.sd.set_device_state_preset('cache_clip')
            try:
                adapter = self.sd.adapter
                # make sure the adapter has attributes
                if adapter is None:
                    raise Exception("Error: must have an adapter to cache clip vision to disk")

                clip_image_processor = None
                if hasattr(adapter, 'clip_image_processor'):
                    clip_image_processor = adapter.clip_image_processor
                if clip_image_processor is None:
                    raise Exception("Error: must have a clip image processor to cache clip vision to disk")

                # `vision_encoder` takes precedence over `image_encoder` when both exist
                vision_encoder = None
                if hasattr(adapter, 'image_encoder'):
                    vision_encoder = adapter.image_encoder
                if hasattr(adapter, 'vision_encoder'):
                    vision_encoder = adapter.vision_encoder
                if vision_encoder is None:
                    raise Exception("Error: must have a vision encoder to cache clip vision to disk")

                # move vision encoder to device
                vision_encoder.to(self.sd.device)

                is_quad = adapter.config.quad_image
                image_encoder_path = adapter.config.image_encoder_path

                dtype = self.sd.torch_dtype
                device = self.sd.device_torch

                is_noise_zero = hasattr(adapter, 'clip_noise_zero') and adapter.clip_noise_zero
                if not is_noise_zero:
                    # only need one since it doesnt change
                    self.clip_vision_num_unconditional_cache = 1

                # cache unconditionals
                print_acc(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
                clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache')

                # the hash only depends on the encoder settings, so compute it once:
                # base64 of the md5 checksum of the settings, without '=' padding
                hash_dict = OrderedDict([
                    ("image_encoder_path", image_encoder_path),
                    ("is_quad", is_quad),
                    ("is_noise_zero", is_noise_zero),
                ])
                hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
                hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
                hash_str = hash_str.replace('=', '')

                unconditional_paths = []
                for i in range(self.clip_vision_num_unconditional_cache):
                    uncond_path = os.path.join(clip_vision_cache_path, f'uncond_{hash_str}_{i}.safetensors')
                    if os.path.exists(uncond_path):
                        # already cached by a previous run
                        unconditional_paths.append(uncond_path)
                        continue

                    # generate the unconditional input image
                    img_shape = (1, 3, adapter.input_size, adapter.input_size)
                    if is_noise_zero:
                        tensors_0_1 = torch.rand(img_shape).to(device, dtype=torch.float32)
                    else:
                        tensors_0_1 = torch.zeros(img_shape).to(device, dtype=torch.float32)
                    clip_image = clip_image_processor(
                        images=tensors_0_1,
                        return_tensors="pt",
                        do_resize=True,
                        do_rescale=False,
                    ).pixel_values

                    if is_quad:
                        clip_image = CLIPCachingMixin._clip_quad_split(clip_image)

                    clip_output = vision_encoder(
                        clip_image.to(device, dtype=dtype),
                        output_hidden_states=True
                    )
                    state_dict = CLIPCachingMixin._clip_output_to_state_dict(clip_output)

                    os.makedirs(os.path.dirname(uncond_path), exist_ok=True)
                    save_file(state_dict, uncond_path)
                    unconditional_paths.append(uncond_path)

                self.clip_vision_unconditional_cache = unconditional_paths

                # use tqdm to show progress
                for file_item in tqdm(self.file_list, desc=f'Caching clip vision to disk'):
                    file_item.is_caching_clip_vision_to_disk = True
                    file_item.clip_vision_load_device = self.sd.device
                    file_item.clip_vision_is_quad = is_quad
                    file_item.clip_image_encoder_path = image_encoder_path
                    file_item.clip_vision_unconditional_paths = unconditional_paths
                    if file_item.has_clip_augmentations:
                        raise Exception("Error: clip vision caching is not supported with clip augmentations")

                    embedding_path = file_item.get_clip_vision_embeddings_path(recalculate=True)
                    # check if it is saved to disk already
                    if not os.path.exists(embedding_path):
                        # load the image first
                        file_item.load_clip_image()
                        # add batch dimension
                        clip_image = file_item.clip_image_tensor.unsqueeze(0).to(device, dtype=dtype)

                        if is_quad:
                            clip_image = CLIPCachingMixin._clip_quad_split(clip_image)

                        clip_output = vision_encoder(
                            clip_image.to(device, dtype=dtype),
                            output_hidden_states=True
                        )
                        state_dict = CLIPCachingMixin._clip_output_to_state_dict(clip_output)

                        # metadata
                        meta = get_meta_for_safetensors(file_item.get_clip_vision_info_dict())
                        os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
                        save_file(state_dict, embedding_path, metadata=meta)

                        # drop the references so the tensors can be freed before the next item
                        del clip_image
                        del clip_output
                        del file_item.clip_image_tensor

                    file_item.is_vision_clip_cached = True
            finally:
                # restore device state, also when caching failed part-way
                self.sd.restore_device_state()
def hacked_state_dict(self, *args, **kwargs):
    """Return the model's state dict with quantized weights dequantized.

    Reads the real state dict through ``self.orig_state_dict`` (installed by
    :func:`patch_dequantization_on_save`). Every ``<name>._data`` entry is
    multiplied by its ``<name>._scale`` and stored under ``<name>`` on the CPU
    in the scale's dtype; the ``._scale`` / ``.input_scale`` /
    ``.output_scale`` bookkeeping entries are dropped. All other entries are
    passed through unchanged.

    Raises:
        ValueError: if an ``input_scale`` or ``output_scale`` entry exists for
            a quantized weight and is not exactly 1.0.
    """
    orig_state_dict = self.orig_state_dict(*args, **kwargs)
    data_suffix = "._data"
    new_state_dict = {}
    for key, value in orig_state_dict.items():
        # scales are consumed together with their `._data` entry below
        if key.endswith(("._scale", ".input_scale", ".output_scale")):
            continue
        if not key.endswith(data_suffix):
            new_state_dict[key] = value
            continue

        key = key[:-len(data_suffix)]
        scale = orig_state_dict[key + "._scale"]
        # scale is the original dtype
        dtype = scale.dtype

        # activation scales cannot be folded into the weight, so only the
        # identity value is accepted
        for label, suffix in (("Input", ".input_scale"), ("Output", ".output_scale")):
            extra_scale = orig_state_dict.get(key + suffix)
            if extra_scale is not None and extra_scale.item() != 1.0:
                raise ValueError(f"{label} scale is not 1.0, cannot dequantize")

        dequantized = value.float() * scale.float()
        new_state_dict[key] = dequantized.to('cpu', dtype=dtype)
    return new_state_dict


# hacks the state dict so we can dequantize before saving
def patch_dequantization_on_save(model):
    """Replace ``model.state_dict`` with the dequantizing :func:`hacked_state_dict`.

    The original bound method is kept as ``model.orig_state_dict``. Calling
    this more than once on the same model is a no-op: re-patching would store
    the already-patched ``state_dict`` as ``orig_state_dict`` and make every
    later ``state_dict()`` call recurse forever.
    """
    current = getattr(model, "state_dict", None)
    if isinstance(current, partial) and current.func is hacked_state_dict:
        return
    model.orig_state_dict = model.state_dict
    model.state_dict = partial(hacked_state_dict, model)


def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
    """
    Convert a quantized parameter back to a regular Parameter with floating point values.

    Args:
        module: The module containing the parameter to unquantize
        param_name: Name of the parameter to unquantize (e.g., 'weight', 'bias')

    Returns:
        bool: True if parameter was unquantized, False if it was already unquantized

    Raises:
        AttributeError: if ``module`` has no attribute called ``param_name``
        TypeError: if that attribute is not a ``torch.nn.Parameter``
    """
    # Check if the parameter exists
    if not hasattr(module, param_name):
        raise AttributeError(f"Module has no parameter named '{param_name}'")

    param = getattr(module, param_name)

    if not isinstance(param, torch.nn.Parameter):
        raise TypeError(f"'{param_name}' is not a Parameter")
    # not quantized: nothing to do
    if not isinstance(param, QTensor):
        return False

    # Convert to float tensor while preserving requires_grad
    # NOTE(review): relies on `.float()` of a QTensor yielding a plain
    # dequantized tensor — confirm against the optimum.quanto version in use
    with torch.no_grad():
        float_tensor = param.float()
        new_param = torch.nn.Parameter(
            float_tensor,
            requires_grad=param.requires_grad
        )

    # Replace the parameter
    setattr(module, param_name, new_param)

    return True
+ Note that EMA is computed on *all* provided parameters, + regardless of whether or not they have `requires_grad = True`; + this allows a single EMA object to be consistantly used even + if which parameters are trainable changes step to step. + + If you want to some parameters in the EMA, do not pass them + to the object in the first place. For example: + + ExponentialMovingAverage( + parameters=[p for p in model.parameters() if p.requires_grad], + decay=0.9 + ) + + will ignore parameters that do not require grad. + + decay: The exponential decay. + + use_num_updates: Whether to use number of updates when computing + averages. + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter] = None, + decay: float = 0.995, + use_num_updates: bool = False, + # feeds back the decat to the parameter + use_feedback: bool = False, + param_multiplier: float = 1.0 + ): + if parameters is None: + raise ValueError("parameters must be provided") + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.decay = decay + self.num_updates = 0 if use_num_updates else None + self.use_feedback = use_feedback + self.param_multiplier = param_multiplier + parameters = list(parameters) + self.shadow_params = [ + p.clone().detach() + for p in parameters + ] + self.collected_params = None + self._is_train_mode = True + # By maintaining only a weakref to each parameter, + # we maintain the old GC behaviour of ExponentialMovingAverage: + # if the model goes out of scope but the ExponentialMovingAverage + # is kept, no references to the model or its parameters will be + # maintained, and the model will be cleaned up. 
+ self._params_refs = [weakref.ref(p) for p in parameters] + + def _get_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] + ) -> Iterable[torch.nn.Parameter]: + if parameters is None: + parameters = [p() for p in self._params_refs] + if any(p is None for p in parameters): + raise ValueError( + "(One of) the parameters with which this " + "ExponentialMovingAverage " + "was initialized no longer exists (was garbage collected);" + " please either provide `parameters` explicitly or keep " + "the model to which they belong from being garbage " + "collected." + ) + return parameters + else: + parameters = list(parameters) + if len(parameters) != len(self.shadow_params): + raise ValueError( + "Number of parameters passed as argument is different " + "from number of shadow parameters maintained by this " + "ExponentialMovingAverage" + ) + return parameters + + def update( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Update currently maintained parameters. + + Call this every time the parameters are updated, such as the result of + the `optimizer.step()` call. + + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. 
+ """ + parameters = self._get_parameters(parameters) + decay = self.decay + if self.num_updates is not None: + self.num_updates += 1 + decay = min( + decay, + (1 + self.num_updates) / (10 + self.num_updates) + ) + one_minus_decay = 1.0 - decay + with torch.no_grad(): + for s_param, param in zip(self.shadow_params, parameters): + s_param_float = s_param.float() + if s_param.dtype != torch.float32: + s_param_float = s_param_float.to(torch.float32) + param_float = param + if param.dtype != torch.float32: + param_float = param_float.to(torch.float32) + tmp = (s_param_float - param_float) + # tmp will be a new tensor so we can do in-place + tmp.mul_(one_minus_decay) + s_param_float.sub_(tmp) + + update_param = False + if self.use_feedback: + # make feedback 10x decay + param_float.add_(tmp * 10) + update_param = True + + if self.param_multiplier != 1.0: + param_float.mul_(self.param_multiplier) + update_param = True + + if s_param.dtype != torch.float32: + copy_stochastic(s_param, s_param_float) + + if update_param and param.dtype != torch.float32: + copy_stochastic(param, param_float) + + + def copy_to( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def store( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Save the current parameters for restoring later. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. 
If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.collected_params = [ + param.clone() + for param in parameters + ] + + def restore( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) + parameters = self._get_parameters(parameters) + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) + + @contextlib.contextmanager + def average_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ): + r""" + Context manager for validation/inference with averaged parameters. + + Equivalent to: + + ema.store() + ema.copy_to() + try: + ... + finally: + ema.restore() + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.store(parameters) + self.copy_to(parameters) + try: + yield + finally: + self.restore(parameters) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
+ + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.collected_params + ] + return + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict.""" + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "num_updates": self.num_updates, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params + } + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. + + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.num_updates = state_dict["num_updates"] + assert self.num_updates is None or isinstance(self.num_updates, int), \ + "Invalid num_updates" + + self.shadow_params = state_dict["shadow_params"] + assert isinstance(self.shadow_params, list), \ + "shadow_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.shadow_params + ), "shadow_params must all be Tensors" + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + assert isinstance(self.collected_params, list), \ + "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len(self.shadow_params), \ + "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistant with torch.optim.Optimizer, cast things to consistant + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = self.shadow_params[i].to( + device=p.device, dtype=p.dtype + ) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." 
class Embedding:
    """Textual-inversion embedding attached to the tokenizer(s) and text encoder(s) of ``sd``.

    On construction the trigger word (plus ``<trigger>_1`` … for multi-vector
    embeddings) is added to every tokenizer, the text-encoder embedding tables
    are resized, and the new rows are initialised from ``embed_config.init_words``.
    The resulting vectors are exposed through :attr:`vec` (first text encoder)
    and :attr:`vec2` (second text encoder).
    """

    def __init__(
            self,
            sd: 'StableDiffusion',
            embed_config: 'EmbeddingConfig',
            state_dict: OrderedDict = None,
    ):
        """
        Args:
            sd: model wrapper providing ``tokenizer`` and ``text_encoder``
                (each either a single object or a list).
            embed_config: provides ``trigger``, ``tokens`` and ``init_words``.
            state_dict: accepted for interface compatibility; not read here.

        Raises:
            ValueError: if a tokenizer did not add exactly ``embed_config.tokens``
                new tokens (i.e. some placeholder already existed).
        """
        self.name = embed_config.trigger
        self.sd = sd
        self.trigger = embed_config.trigger
        self.embed_config = embed_config
        self.step = 0

        # placeholder tokens: the trigger itself plus dummy tokens for multi-vector
        placeholder_tokens = [self.embed_config.trigger]
        placeholder_tokens += [
            f"{self.embed_config.trigger}_{i}" for i in range(1, self.embed_config.tokens)
        ]

        # handle dual tokenizer
        self.tokenizer_list = self.sd.tokenizer if isinstance(self.sd.tokenizer, list) else [self.sd.tokenizer]
        self.text_encoder_list = self.sd.text_encoder if isinstance(self.sd.text_encoder, list) else [
            self.sd.text_encoder]

        self.placeholder_token_ids = []
        self.embedding_tokens = []

        print(f"Adding {placeholder_tokens} tokens to tokenizer")
        print(f"Adding {self.embed_config.tokens} tokens to tokenizer")

        for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
            num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
            if num_added_tokens != self.embed_config.tokens:
                raise ValueError(
                    f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
                    f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
                )

            # Convert the initializer_token, placeholder_token to ids
            init_token_ids = tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
            # truncate, or pad with the id(s) of "*", so there is one init id per embedding token
            if len(init_token_ids) > self.embed_config.tokens:
                init_token_ids = init_token_ids[:self.embed_config.tokens]
            elif len(init_token_ids) < self.embed_config.tokens:
                pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
                init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))

            placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
            self.placeholder_token_ids.append(placeholder_token_ids)

            # Resize the token embeddings as we are adding new special tokens to the tokenizer
            text_encoder.resize_token_embeddings(len(tokenizer))

            # Initialise the newly added placeholder token with the embeddings of the initializer token
            token_embeds = text_encoder.get_input_embeddings().weight.data
            with torch.no_grad():
                for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
                    token_embeds[token_id] = token_embeds[initializer_token_id].clone()

            # replace "[name] with this. on training. This is automatically generated in pipeline on inference
            self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))

        # backup text encoder embeddings
        self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]

    def restore_embeddings(self):
        """Reset every embedding row except the placeholder rows to its backed-up value."""
        with torch.no_grad():
            # Let's make sure we don't update any embedding weights besides the newly added token
            for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
                                                                                   self.tokenizer_list,
                                                                                   self.orig_embeds_params,
                                                                                   self.placeholder_token_ids):
                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
                # treats the placeholder ids as one contiguous id range
                index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
                text_encoder.get_input_embeddings().weight[
                    index_no_updates
                ] = orig_embeds[index_no_updates]

    def get_trainable_params(self):
        """Return the input-embedding parameters of every text encoder as one flat list."""
        params = []
        for text_encoder in self.text_encoder_list:
            params += text_encoder.get_input_embeddings().parameters()
        return params

    def _get_vec(self, text_encoder_idx=0):
        """Stack the placeholder rows of text encoder ``text_encoder_idx`` into one tensor."""
        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
        # stack the tokens along batch axis adding that axis
        new_vector = torch.stack(
            [token_embeds[token_id] for token_id in self.placeholder_token_ids[text_encoder_idx]],
            dim=0
        )
        return new_vector

    def _set_vec(self, new_vector, text_encoder_idx=0):
        """Write the rows of ``new_vector`` into the placeholder rows of the given text encoder."""
        # shape is (1, 768) for SD 1.5 for 1 token
        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
        for i in range(new_vector.shape[0]):
            token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()

    # make setter and getter for vec
    @property
    def vec(self):
        return self._get_vec(0)

    @vec.setter
    def vec(self, new_vector):
        self._set_vec(new_vector, 0)

    @property
    def vec2(self):
        return self._get_vec(1)

    @vec2.setter
    def vec2(self, new_vector):
        self._set_vec(new_vector, 1)

    # diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
    # however, on training we don't use that pipeline, so we have to do it ourselves
    def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
        """Replace placeholder markers in ``prompt`` with the embedding trigger.

        Args:
            prompt: text to rewrite.
            expand_token: use the expanded multi-token string instead of the bare trigger.
            to_replace_list: extra markers to replace besides ``[name]`` and
                ``[trigger]``. The list is only read, never modified.
            add_if_not_present: prepend the trigger when the prompt ends up without it.

        Returns:
            The rewritten prompt. A warning is printed when the trigger occurs more than once.
        """
        output_prompt = prompt
        embedding_tokens = self.embedding_tokens[0]  # should be the same for every tokenizer
        default_replacements = ["[name]", "[trigger]"]

        replace_with = embedding_tokens if expand_token else self.trigger

        # build a fresh, order-preserving, duplicate-free list; extending the
        # caller's list in place would make it grow on every call
        markers = list(dict.fromkeys([*(to_replace_list or []), *default_replacements]))

        for to_replace in markers:
            output_prompt = output_prompt.replace(to_replace, replace_with)

        # see how many times replace_with is in the prompt
        num_instances = output_prompt.count(replace_with)

        if num_instances == 0 and add_if_not_present:
            # add it to the beginning of the prompt
            output_prompt = replace_with + " " + output_prompt

        if num_instances > 1:
            print(
                f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

        return output_prompt

    def state_dict(self):
        """Return the embedding vectors keyed for saving (``clip_l``/``clip_g`` for XL, else ``emb_params``)."""
        state_dict = OrderedDict()
        if self.sd.is_xl:
            state_dict['clip_l'] = self.vec
            state_dict['clip_g'] = self.vec2
        else:
            state_dict['emb_params'] = self.vec

        return state_dict

    def save(self, filename):
        """Write the embedding to ``filename``; the format is chosen by extension.

        ``.pt`` / ``.bin`` are written with ``torch.save``, ``.safetensors`` with
        ``save_file`` plus metadata. Any other extension writes nothing.
        """
        # todo check to see how to get the vector out of the embedding

        embedding_data = {
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            # todo get these
            "sd_checkpoint": None,
            "sd_checkpoint_name": None,
            "notes": None,
        }
        # TODO we do not currently support this. Check how auto is doing it. Only safetensors supported sor sdxl
        if filename.endswith(('.pt', '.bin')):
            torch.save(embedding_data, filename)
        elif filename.endswith('.safetensors'):
            # save the embedding as a safetensors file
            state_dict = self.state_dict()
            # add all embedding data (except string_to_param), to metadata
            metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
            metadata["string_to_param"] = {"*": "emb_params"}
            save_meta = get_meta_for_safetensors(metadata, name=self.name)
            save_file(state_dict, filename, metadata=save_meta)

    def load_embedding_from_file(self, file_path, device):
        """Load embedding vectors (and the step counter, when stored) from ``file_path``.

        Supports ``.bin`` / ``.pt`` (non-XL only) and ``.safetensors``. Preview
        images and unknown extensions are ignored silently.

        Raises:
            Exception: for ``.bin`` / ``.pt`` files on an XL model, or when the
                file content is neither a textual-inversion embedding nor a
                diffusers concept.
        """
        # full path
        path = os.path.realpath(file_path)
        filename = os.path.basename(path)
        name, ext = os.path.splitext(filename)
        tensors = {}
        ext = ext.upper()
        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == '.PREVIEW':
                return

        if ext in ['.BIN', '.PT']:
            # todo check this
            if self.sd.is_xl:
                raise Exception("XL not supported yet for bin, pt")
            # NOTE(security): torch.load unpickles the file; only load embeddings from trusted sources
            data = torch.load(path, map_location="cpu")
        elif ext in ['.SAFETENSORS']:
            # rebuild the embedding from the safetensors file if it has it
            with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                for k in f.keys():
                    tensors[k] = f.get_tensor(k)
            if metadata and 'string_to_param' in metadata and 'emb_params' in tensors:
                # our format: metadata values are JSON strings where possible
                def try_json(v):
                    try:
                        return json.loads(v)
                    except (TypeError, ValueError):
                        return v

                data = {k: try_json(v) for k, v in metadata.items()}
                data['string_to_param'] = {'*': tensors['emb_params']}
            else:
                # old format
                data = tensors
        else:
            return

        if self.sd.is_xl:
            self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
            self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
            if 'step' in data:
                self.step = int(data['step'])
            return

        # textual inversion embeddings
        if 'string_to_param' in data:
            param_dict = data['string_to_param']
            if hasattr(param_dict, '_parameters'):
                # fix for torch 1.12.1 loading saved file from torch 1.11
                param_dict = getattr(param_dict, '_parameters')
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
        # diffuser concepts
        elif isinstance(data, dict) and data and isinstance(next(iter(data.values())), torch.Tensor):
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
        else:
            raise Exception(
                f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

        if 'step' in data:
            self.step = int(data['step'])

        self.vec = emb.detach().to(device, dtype=torch.float32)
'conv_body.weight', + 'model.1.sub.23.bias': 'conv_body.bias', + 'model.3.weight': 'conv_up1.weight', + 'model.3.bias': 'conv_up1.bias', + 'model.6.weight': 'conv_up2.weight', + 'model.6.bias': 'conv_up2.bias', + 'model.8.weight': 'conv_hr.weight', + 'model.8.bias': 'conv_hr.bias', + 'model.10.bias': 'conv_last.bias', + 'model.10.weight': 'conv_last.weight', + # 'model.1.sub.0.RDB1.conv1.0.weight': 'body.0.rdb1.conv1.weight' +} + +def convert_state_dict_to_basicsr(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if k in to_basicsr_dict: + new_state_dict[to_basicsr_dict[k]] = v + elif k.startswith('model.1.sub.'): + bsr_name = k.replace('model.1.sub.', 'body.').lower() + bsr_name = bsr_name.replace('.0.weight', '.weight') + bsr_name = bsr_name.replace('.0.bias', '.bias') + new_state_dict[bsr_name] = v + else: + new_state_dict[k] = v + return new_state_dict + + +# just matching a commonly used format +def convert_basicsr_state_dict_to_save_format(state_dict): + new_state_dict = {} + to_basicsr_dict_values = list(to_basicsr_dict.values()) + for k, v in state_dict.items(): + if k in to_basicsr_dict_values: + for key, value in to_basicsr_dict.items(): + if value == k: + new_state_dict[key] = v + + elif k.startswith('body.'): + bsr_name = k.replace('body.', 'model.1.sub.').lower() + bsr_name = bsr_name.replace('rdb', 'RDB') + bsr_name = bsr_name.replace('.weight', '.0.weight') + bsr_name = bsr_name.replace('.bias', '.0.bias') + new_state_dict[bsr_name] = v + else: + new_state_dict[k] = v + return new_state_dict diff --git a/ai-toolkit/toolkit/extension.py b/ai-toolkit/toolkit/extension.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f10e9d4ab3ed3f335d6e86f17faba17a40632d --- /dev/null +++ b/ai-toolkit/toolkit/extension.py @@ -0,0 +1,57 @@ +import os +import importlib +import pkgutil +from typing import List + +from toolkit.paths import TOOLKIT_ROOT + + +class Extension(object): + """Base class for extensions. 
+ + Extensions are registered with the ExtensionManager, which is + responsible for calling the extension's load() and unload() + methods at the appropriate times. + + """ + + name: str = None + uid: str = None + + @classmethod + def get_process(cls): + # extend in subclass + pass + + +def get_all_extensions() -> List[Extension]: + extension_folders = ['extensions', 'extensions_built_in'] + + # This will hold the classes from all extension modules + all_extension_classes: List[Extension] = [] + + # Iterate over all directories (i.e., packages) in the "extensions" directory + for sub_dir in extension_folders: + extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir) + for (_, name, _) in pkgutil.iter_modules([extensions_dir]): + # try: + # Import the module + module = importlib.import_module(f"{sub_dir}.{name}") + # Get the value of the AI_TOOLKIT_EXTENSIONS variable + extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None) + # Check if the value is a list + if isinstance(extensions, list): + # Iterate over the list and add the classes to the main list + all_extension_classes.extend(extensions) + # except ImportError as e: + # print(f"Failed to import the {name} module. 
def get_differential_mask(
        conditional_latents: torch.Tensor,
        unconditional_latents: torch.Tensor,
        threshold: float = 0.2,
        gradient: bool = False,
):
    """Build a 0-1 mask of where two latent batches differ.

    The absolute difference is normalized per sample (dim 0) by that sample's
    maximum over all remaining dims, then either binarized or remapped.

    Args:
        conditional_latents: tensor with a leading batch dim and at least one
            more dim.
        unconditional_latents: tensor the difference is taken against; it is
            subtracted from ``conditional_latents`` directly.
        threshold: in binary mode, normalized differences below this become
            0.0 and everything else 1.0. In gradient mode, the normalized mask
            is remapped to ``[-threshold, 1 + threshold]`` and then clamped to
            ``[0, 1]``.
        gradient: select the soft (remap + clamp) mask instead of the binary one.

    Returns:
        A tensor shaped like the input difference with values in ``[0, 1]``.

    Raises:
        ValueError: if the inputs have fewer than 2 dims (no non-batch dim to
            normalize over).
    """
    # make a differential mask
    differential_mask = torch.abs(conditional_latents - unconditional_latents)
    if differential_mask.dim() < 2:
        raise ValueError(
            f"expected latents with a batch dim plus at least one more dim, got shape "
            f"{tuple(differential_mask.shape)}"
        )

    # Per-sample max over every non-batch dim. One amax call is equivalent to
    # the chained per-dim max for 4D/5D inputs and also covers other ranks.
    reduce_dims = tuple(range(1, differential_mask.dim()))
    max_differential = differential_mask.amax(dim=reduce_dims, keepdim=True)

    # A sample whose two latents are identical has max == 0; dividing by it
    # would produce NaN (0 * inf). Scale such samples by 1 so they stay all-zero.
    safe_max = torch.where(
        max_differential > 0,
        max_differential,
        torch.ones_like(max_differential)
    )
    differential_scaler = 1.0 / safe_max
    differential_mask = differential_mask * differential_scaler

    if gradient:
        # we need to scale it to 0-1
        # add the threshold to both sides and clip
        # NOTE: min()/max() here are over the whole batch; if every element of
        # the mask is equal, the result depends on how value_map handles a
        # zero-width input range.
        differential_mask = value_map(
            differential_mask,
            differential_mask.min(),
            differential_mask.max(),
            0 - threshold,
            1 + threshold
        )
        differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
    else:
        # make everything less than the threshold be 0.0 and everything else be 1.0
        differential_mask = torch.where(
            differential_mask < threshold,
            torch.zeros_like(differential_mask),
            torch.ones_like(differential_mask)
        )
    return differential_mask
unconditional_noise = noise + conditional_diff_noise + + conditional_noisy_latents = sd.add_noise( + conditional_latents, + conditional_noise, + timesteps + ).detach() + + unconditional_noisy_latents = sd.add_noise( + unconditional_latents, + unconditional_noise, + timesteps + ).detach() + + # double up everything to run it through all at once + cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + # cat_baseline_noisy_latents = torch.cat( + # [baseline_conditional_noisy_latents, baseline_unconditional_noisy_latents], + # dim=0 + # ) + + # Disable the LoRA network so we can predict parent network knowledge without it + # sd.network.is_active = False + # sd.unet.eval() + + # Predict noise to get a baseline of what the parent network wants to do with the latents + noise. + # This acts as our control to preserve the unaltered parts of the image. + # baseline_prediction = sd.predict_noise( + # latents=cat_baseline_noisy_latents.to(device, dtype=dtype).detach(), + # conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(), + # timestep=cat_timesteps, + # guidance_scale=1.0, + # **pred_kwargs # adapter residuals in here + # ).detach() + + # conditional_baseline_prediction, unconditional_baseline_prediction = torch.chunk(baseline_prediction, 2, dim=0) + + # negative_network_weights = [weight * -1.0 for weight in network_weight_list] + # positive_network_weights = [weight * 1.0 for weight in network_weight_list] + # cat_network_weight_list = positive_network_weights + negative_network_weights + + # turn the LoRA network back on. 
+ sd.unet.train() + # sd.network.is_active = True + + # sd.network.multiplier = cat_network_weight_list + + # do our prediction with LoRA active on the scaled guidance latents + prediction = sd.predict_noise( + latents=cat_latents.to(device, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(), + timestep=cat_timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + # prediction = prediction - baseline_prediction + + pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) + # pred_pos = pred_pos - conditional_baseline_prediction + # pred_neg = pred_neg - unconditional_baseline_prediction + + pred_loss = torch.nn.functional.mse_loss( + pred_pos.float(), + conditional_noise.float(), + reduction="none" + ) + pred_loss = pred_loss.mean([1, 2, 3]) + + pred_neg_loss = torch.nn.functional.mse_loss( + pred_neg.float(), + unconditional_noise.float(), + reduction="none" + ) + pred_neg_loss = pred_neg_loss.mean([1, 2, 3]) + + loss = pred_loss + pred_neg_loss + + loss = loss.mean() + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + +def get_direct_guidance_loss( + noisy_latents: torch.Tensor, + conditional_embeds: 'PromptEmbeds', + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + unconditional_embeds: Optional[PromptEmbeds] = None, + mask_multiplier=None, + prior_pred=None, + **kwargs +): + with torch.no_grad(): + # Perform targeted guidance (working title) + dtype = get_torch_dtype(sd.torch_dtype) + device = sd.device_torch + + + conditional_latents = batch.latents.to(device, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + + conditional_noisy_latents = sd.add_noise( + conditional_latents, + # 
# targeted
def get_targeted_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    """Compute and back-propagate the "targeted" guidance loss.

    Two predictions are made. First, with ``sd.network`` deactivated, a
    baseline prediction on the noised unconditional latents. Second, with the
    network active, one batched prediction on the noised conditional and
    unconditional latents. The loss is the sum of:

    * ``|conditional_loss - prior_prediction_loss|`` scaled per element by a
      map of the latent differential into ``[1.0, 2.0]``, and
    * ``|conditional_loss - unconditional_loss|``.

    ``backward()`` is called inside this function. ``noisy_latents`` and
    ``match_adapter_assist`` are not read by the body.

    Side effects: toggles ``sd.network.is_active`` and ``sd.unet`` train/eval
    mode, and sets ``sd.network.multiplier`` (restored to
    ``network_weight_list`` before returning).

    Returns:
        The scalar loss, detached, with ``requires_grad`` switched on so a
        later ``backward()`` by the caller does not raise.
    """
    with torch.no_grad():
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # Add the same noise at the same timesteps to both latent sets
        unconditional_noisy_latents = sd.noise_scheduler.add_noise(
            unconditional_latents,
            noise,
            timesteps
        )
        conditional_noisy_latents = sd.noise_scheduler.add_noise(
            conditional_latents,
            noise,
            timesteps
        )

        # was_network_active = self.network.is_active
        sd.network.is_active = False
        sd.unet.eval()

        target_differential = unconditional_latents - conditional_latents
        # scale our loss by the differential scaler
        target_differential_abs = target_differential.abs()
        # NOTE(review): this "min" reduces with min over dim 1 but max over
        # dims 2 and 3, unlike the all-max chain below — confirm whether a
        # min over every dim was intended.
        target_differential_abs_min = \
            target_differential_abs.min(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
        target_differential_abs_max = \
            target_differential_abs.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]

        min_guidance = 1.0
        max_guidance = 2.0

        # Map the absolute differential onto [min_guidance, max_guidance] to
        # weight the positive loss further down.
        differential_scaler = value_map(
            target_differential_abs,
            target_differential_abs_min,
            target_differential_abs_max,
            min_guidance,
            max_guidance
        ).detach()

        # With LoRA network bypassed, predict noise to get a baseline of what the network
        # wants to do with the latents + noise. Pass our target latents here for the input.
        target_unconditional = sd.predict_noise(
            latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs  # adapter residuals in here
        ).detach()
        prior_prediction_loss = torch.nn.functional.mse_loss(
            target_unconditional.float(),
            noise.float(),
            reduction="none"
        ).detach().clone()

    # Everything below runs with gradients enabled: backward() is called on the result.
    # turn the LoRA network back on.
    sd.unet.train()
    sd.network.is_active = True
    # One multiplier per item of the doubled batch built below.
    # NOTE(review): the second half is `x + -1.0` (an offset of -1), whereas
    # get_guided_loss_polarity negates with `weight * -1.0` — confirm the
    # offset is intentional.
    sd.network.multiplier = network_weight_list + [x + -1.0 for x in network_weight_list]

    # with LoRA active, predict the noise with the scaled differential latents added. This will allow us
    # the opportunity to predict the differential + noise that was added to the latents.
    prediction = sd.predict_noise(
        latents=torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0).to(device, dtype=dtype).detach(),
        conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(),
        timestep=torch.cat([timesteps, timesteps], dim=0),
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    # First half of the batch is the conditional input, second half the unconditional one.
    prediction_conditional, prediction_unconditional = torch.chunk(prediction, 2, dim=0)

    conditional_loss = torch.nn.functional.mse_loss(
        prediction_conditional.float(),
        noise.float(),
        reduction="none"
    )

    unconditional_loss = torch.nn.functional.mse_loss(
        prediction_unconditional.float(),
        noise.float(),
        reduction="none"
    )

    # Distance between the LoRA-active conditional error and the baseline error.
    positive_loss = torch.abs(
        conditional_loss.float() - prior_prediction_loss.float(),
    )
    # scale our loss by the differential scaler
    positive_loss = positive_loss * differential_scaler

    positive_loss = positive_loss.mean([1, 2, 3])

    # Distance between the conditional and unconditional errors.
    polar_loss = torch.abs(
        conditional_loss.float() - unconditional_loss.float(),
    ).mean([1, 2, 3])

    positive_loss = positive_loss.mean() + polar_loss.mean()

    positive_loss.backward()
    # loss = positive_loss.detach() + negative_loss.detach()
    loss = positive_loss.detach()

    # add a grad so other backward does not fail
    loss.requires_grad_(True)

    # restore network
    sd.network.multiplier = network_weight_list

    return loss
concat_prompt_embeds([conditional_embeds, conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + + negative_network_weights = [weight * -1.0 for weight in network_weight_list] + positive_network_weights = [weight * 1.0 for weight in network_weight_list] + cat_network_weight_list = positive_network_weights + negative_network_weights + + # turn the LoRA network back on. + sd.unet.train() + sd.network.is_active = True + + sd.network.multiplier = cat_network_weight_list + + # do our prediction with LoRA active on the scaled guidance latents + prediction = sd.predict_noise( + latents=cat_latents.to(device, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(), + timestep=cat_timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + + pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0) + + pred_loss = torch.nn.functional.mse_loss( + pred_pos.float(), + target_pos.float(), + reduction="none" + ) + # pred_loss = pred_loss.mean([1, 2, 3]) + + pred_neg_loss = torch.nn.functional.mse_loss( + pred_neg.float(), + target_neg.float(), + reduction="none" + ) + + loss = pred_loss + pred_neg_loss + + loss = loss.mean([1, 2, 3]) + loss = loss.mean() + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() + + # detach it so parent class can run backward on no grads without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + + + +def get_guided_tnt( + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + prior_pred: torch.Tensor = None, + **kwargs +): + dtype = get_torch_dtype(sd.torch_dtype) + device = sd.device_torch + with torch.no_grad(): + dtype = 
get_torch_dtype(dtype) + noise = noise.to(device, dtype=dtype).detach() + + conditional_latents = batch.latents.to(device, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + + conditional_noisy_latents = sd.add_noise( + conditional_latents, + noise, + timesteps + ).detach() + + unconditional_noisy_latents = sd.add_noise( + unconditional_latents, + noise, + timesteps + ).detach() + + # double up everything to run it through all at once + cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0) + cat_timesteps = torch.cat([timesteps, timesteps], dim=0) + + + # turn the LoRA network back on. + sd.unet.train() + if sd.network is not None: + cat_network_weight_list = [weight for weight in network_weight_list * 2] + sd.network.multiplier = cat_network_weight_list + sd.network.is_active = True + + + prediction = sd.predict_noise( + latents=cat_latents.to(device, dtype=dtype).detach(), + conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(), + timestep=cat_timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0) + + this_loss = torch.nn.functional.mse_loss( + this_prediction.float(), + noise.float(), + reduction="none" + ) + + that_loss = torch.nn.functional.mse_loss( + that_prediction.float(), + noise.float(), + reduction="none" + ) + + this_loss = this_loss.mean([1, 2, 3]) + # negative loss on that + that_loss = -that_loss.mean([1, 2, 3]) + + with torch.no_grad(): + # match that loss with this loss so it is not a negative value and same scale + that_loss_scaler = torch.abs(this_loss) / torch.abs(that_loss) + + that_loss = that_loss * that_loss_scaler * 0.01 + + loss = this_loss + that_loss + + loss = loss.mean() + + loss.backward() + + # detach it so parent class can run backward on no grads 
without throwing error + loss = loss.detach() + loss.requires_grad_(True) + + return loss + +def targeted_flow_guidance( + noisy_latents: torch.Tensor, + conditional_embeds: 'PromptEmbeds', + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + sd: 'StableDiffusion', + unconditional_embeds: Optional[PromptEmbeds] = None, + mask_multiplier=None, + prior_pred=None, + scaler=None, + train_config=None, + **kwargs +): + if not sd.is_flow_matching: + raise ValueError("targeted_flow only works on flow matching models") + dtype = get_torch_dtype(sd.torch_dtype) + device = sd.device_torch + with torch.no_grad(): + dtype = get_torch_dtype(dtype) + noise = noise.to(device, dtype=dtype).detach() + + conditional_latents = batch.latents.to(device, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + + # get a mask on the differential of the latents + # this will be scaled from 0.0-1.0 with 1.0 being the largest differential + abs_differential_mask = get_differential_mask( + conditional_latents, + unconditional_latents, + gradient=True + ) + + # get noisy latents for both conditional and unconditional predictions + unconditional_noisy_latents = sd.add_noise( + unconditional_latents, + noise, + timesteps + ).detach() + unconditional_noisy_latents = sd.condition_noisy_latents(unconditional_noisy_latents, batch) + conditional_noisy_latents = sd.add_noise( + conditional_latents, + noise, + timesteps + ).detach() + conditional_noisy_latents = sd.condition_noisy_latents(conditional_noisy_latents, batch) + + # disable the lora to get a baseline prediction + sd.network.is_active = False + sd.unet.eval() + + # get a baseline prediction of the model knowledge without the lora network + # we do this with the unconditional noisy latents + baseline_prediction = sd.predict_noise( + latents=unconditional_noisy_latents.to(device, 
# this processes all guidance losses based on the batch information
def get_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        unconditional_embeds: Optional['PromptEmbeds'] = None,
        mask_multiplier=None,
        prior_pred=None,
        scaler=None,
        train_config=None,
        **kwargs
):
    """Dispatch to the guidance loss selected by the batch's dataset config.

    The guidance type is read from the first file item only
    (``batch.file_items[0].dataset_config.guidance_type``) and applied to the
    whole batch.

    The first nine positional arguments are forwarded unchanged to every loss
    function. The optional arguments are forwarded only to the loss functions
    that accept them:

    * ``scaler`` / ``train_config``: "polarity" and "targeted_flow"
    * ``prior_pred``: "tnt", "direct" and "targeted_flow"
    * ``unconditional_embeds`` / ``mask_multiplier``: "direct" and "targeted_flow"

    Returns:
        Whatever the selected loss function returns.

    Raises:
        AssertionError: if ``unconditional_embeds`` is given for "targeted",
            "polarity", "tnt" or "targeted_polarity".
        NotImplementedError: if the guidance type is not one of the handled values.
    """
    # TODO add others and process individual batch items separately
    guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type

    # Positional arguments shared by every guidance loss, in signature order.
    common_args = (
        noisy_latents,
        conditional_embeds,
        match_adapter_assist,
        network_weight_list,
        timesteps,
        pred_kwargs,
        batch,
        noise,
        sd,
    )

    if guidance_type == "targeted":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
        return get_targeted_guidance_loss(
            *common_args,
            **kwargs
        )
    elif guidance_type == "polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
        return get_guided_loss_polarity(
            *common_args,
            scaler=scaler,
            train_config=train_config,
            **kwargs
        )
    elif guidance_type == "tnt":
        # message previously said "polarity guidance" (copy-paste slip)
        assert unconditional_embeds is None, "Unconditional embeds are not supported for tnt guidance"
        return get_guided_tnt(
            *common_args,
            prior_pred=prior_pred,
            **kwargs
        )
    elif guidance_type == "targeted_polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
        return get_targeted_polarity_loss(
            *common_args,
            **kwargs
        )
    elif guidance_type == "direct":
        return get_direct_guidance_loss(
            *common_args,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            **kwargs
        )
    elif guidance_type == "targeted_flow":
        return targeted_flow_guidance(
            *common_args,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            scaler=scaler,
            train_config=train_config,
            **kwargs
        )
    else:
        raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")
Key order is strict and + branch-dependent: + photo branch: aesthetics, lighting, photo, medium, color_palette + non-photo branch: aesthetics, lighting, medium, art_style, color_palette +- medium is one of: photograph, illustration, 3d_render, painting, graphic_design +- color_palette: UPPERCASE #RRGGBB only, up to 16 per image / 5 per element. +- elements, strict key order: + obj: type, bbox, desc, color_palette + text: type, bbox, text, desc, color_palette + bbox is optional, normalized 0-1000, [y_min, x_min, y_max, x_max], top-left. +- serialize compact: separators=(",", ":"), ensure_ascii=False (no \\uXXXX). + +The OLD format we previously emitted differed by: always using `photo` (even for +non-photo media), putting `color_palette` before `desc`/`text`, title-cased medium +with a trailing period ("Illustration."), and lowercase / 3-digit hex. Every +function here accepts the old shape and returns the new one. +""" + +import json +import re +from collections import OrderedDict + +MAX_IMAGE_PALETTE = 16 # style_description.color_palette +MAX_ELEMENT_PALETTE = 5 # per-element color_palette + +# Canonical medium tokens (official set). +MEDIUM_OPTIONS = [ + "photograph", + "illustration", + "3d_render", + "painting", + "graphic_design", +] + +# Map common variants (including our old "Title." style) to the canonical token. +# Anything not listed is treated as a custom medium and preserved verbatim. 
def sanitize_palette(palette, max_len):
    """Clean a colour list: normalize each entry, drop invalid entries and
    duplicates, keep first-seen order, and stop once ``max_len`` colours have
    been collected.

    Returns the cleaned list, or None when ``palette`` is not a list/tuple or
    nothing valid survives (the caller then omits the key)."""
    if not isinstance(palette, (list, tuple)):
        return None
    # dict keys give ordered de-duplication in a single structure
    unique = {}
    for raw in palette:
        color = normalize_hex(raw)
        if color is not None and color not in unique:
            unique[color] = True
            if len(unique) >= max_len:
                break
    cleaned = list(unique)
    return cleaned if cleaned else None
def normalize_caption_dict(data):
    """Return a normalized copy of a parsed caption dict.

    The result is a new ``OrderedDict`` with a fixed top-level key order:
    ``high_level_description``, ``style_description`` (passed through
    ``normalize_style``), ``compositional_deconstruction``, then every other
    key in its original order. The input-only ``aspect_ratio`` key is never
    copied to the output.

    Inside a dict-valued ``compositional_deconstruction`` the order is
    ``background``, ``elements`` (each item passed through
    ``normalize_element``), then any remaining keys. An ``elements`` value that
    is not a list is omitted. A non-dict, non-None
    ``compositional_deconstruction`` is copied through verbatim.

    Args:
        data: Parsed caption. Anything that is not a dict is returned as-is.

    Returns:
        A new ``OrderedDict``, or ``data`` itself when it is not a dict. The
        caller's dict is left untouched (it used to lose its ``aspect_ratio``
        key as a side effect).
    """
    if not isinstance(data, dict):
        return data

    out = OrderedDict()
    if "high_level_description" in data:
        out["high_level_description"] = data["high_level_description"]
    if "style_description" in data:
        out["style_description"] = normalize_style(data["style_description"])

    decon = data.get("compositional_deconstruction")
    if isinstance(decon, dict):
        nd = OrderedDict()
        if "background" in decon:
            nd["background"] = decon["background"]
        els = decon.get("elements")
        if isinstance(els, list):
            nd["elements"] = [normalize_element(e) for e in els]
        # Keep unknown keys of the block, after the canonical ones.
        for k, v in decon.items():
            if k not in ("background", "elements"):
                nd[k] = v
        out["compositional_deconstruction"] = nd
    elif decon is not None:
        out["compositional_deconstruction"] = decon

    # Remaining top-level keys go last. "aspect_ratio" is input-only context,
    # so it is skipped here instead of being popped from the caller's dict.
    for k, v in data.items():
        if k not in (
            "high_level_description",
            "style_description",
            "compositional_deconstruction",
            "aspect_ratio",
        ):
            out[k] = v
    return out
def digest_caption_string(text):
    """Parse, normalize and re-serialize an Ideogram structured caption.

    The text is treated as a structured caption only when, after stripping
    whitespace, it starts with ``{``, parses as JSON, and the parsed object is
    a dict whose ``compositional_deconstruction`` value is a dict. In that case
    it is run through ``normalize_caption_dict`` and serialized with
    ``to_model_string``.

    Args:
        text: Caption text. Non-string values (``None``, bytes, numbers, ...)
            are accepted and handed back untouched.

    Returns:
        The compact model-ready string for a structured caption; otherwise the
        original ``text`` object, unchanged (plain-text captions pass straight
        through).
    """
    # Same guard as swap_bbox_xy_in_text: a truthy non-string used to raise
    # AttributeError/TypeError here instead of passing through.
    if not isinstance(text, str):
        return text
    t = text.strip()
    if not t.startswith("{"):
        return text
    try:
        # Preserve the author's key order while parsing.
        data = json.loads(t, object_pairs_hook=OrderedDict)
    except Exception:
        # Best effort: malformed JSON is treated as a plain caption.
        return text
    if not (
        isinstance(data, dict)
        and isinstance(data.get("compositional_deconstruction"), dict)
    ):
        return text
    return to_model_string(normalize_caption_dict(data))
def get_image_metadata_from_bytesio(input, size, file_path=None):
    """
    Return an `Image` object for a given img file content - no external
    dependencies except the os and struct builtin modules

    The format is sniffed from the first 26 bytes and only the header fields
    needed for the dimensions are decoded. Supported: GIF, PNG, JPEG, BMP,
    classic TIFF (little- and big-endian) and ICO. Anything of at least two
    bytes that matches no other signature is probed as ICO.

    Args:
        input (io.IOBase): io object support read & seek
        size (int): size of buffer in byte
        file_path (str): path to an image file

    Returns:
        Image: (path, type, file_size, width, height)

    Raises:
        UnknownImageFormat: if the data cannot be decoded as a supported format.
        AssertionError: if an ICO-probed file has a type field other than 1.
    """
    height = -1
    width = -1
    data = input.read(26)
    msg = " raised while trying to decode as JPEG."

    if (size >= 10) and data[:6] in (b'GIF87a', b'GIF89a'):
        # GIFs: little-endian uint16 width/height right after the signature
        imgtype = GIF
        w, h = struct.unpack("<HH", data[6:10])
        width = int(w)
        height = int(h)
    elif ((size >= 24) and data.startswith(b'\211PNG\r\n\032\n')
          and (data[12:16] == b'IHDR')):
        # PNGs
        imgtype = PNG
        w, h = struct.unpack(">LL", data[16:24])
        width = int(w)
        height = int(h)
    elif (size >= 16) and data.startswith(b'\211PNG\r\n\032\n'):
        # older PNGs
        imgtype = PNG
        w, h = struct.unpack(">LL", data[8:16])
        width = int(w)
        height = int(h)
    elif (size >= 2) and data.startswith(b'\377\330'):
        # JPEG: walk the marker segments until a SOF0..SOF3 frame header
        imgtype = JPEG
        input.seek(0)
        input.read(2)
        b = input.read(1)
        try:
            while (b and ord(b) != 0xDA):
                while (ord(b) != 0xFF):
                    b = input.read(1)
                while (ord(b) == 0xFF):
                    b = input.read(1)
                if (ord(b) >= 0xC0 and ord(b) <= 0xC3):
                    input.read(3)
                    h, w = struct.unpack(">HH", input.read(4))
                    break
                else:
                    # skip this segment using its 2-byte length field
                    input.read(
                        int(struct.unpack(">H", input.read(2))[0]) - 2)
                b = input.read(1)
            width = int(w)
            height = int(h)
        except struct.error:
            raise UnknownImageFormat("StructError" + msg)
        except ValueError:
            raise UnknownImageFormat("ValueError" + msg)
        except Exception as e:
            raise UnknownImageFormat(e.__class__.__name__ + msg)
    elif (size >= 26) and data.startswith(b'BM'):
        # BMP
        imgtype = 'BMP'
        headersize = struct.unpack("<I", data[14:18])[0]
        if headersize == 12:
            w, h = struct.unpack("<HH", data[18:22])
            width = int(w)
            height = int(h)
        elif headersize >= 40:
            w, h = struct.unpack("<ii", data[18:26])
            width = int(w)
            # as h is negative when stored upside down
            height = abs(int(h))
        else:
            raise UnknownImageFormat(
                "Unkown DIB header size:" +
                str(headersize))
    elif (size >= 8) and data[:4] in (b"II\052\000", b"MM\000\052"):
        # Standard TIFF, big- or little-endian
        # BigTIFF and other different but TIFF-like formats are not
        # supported currently
        imgtype = TIFF
        byteOrder = data[:2]
        # `data` is bytes, so the marker must be compared against a bytes
        # literal; comparing with the str "MM" is always False on Python 3
        # and made every big-endian TIFF decode as little-endian.
        boChar = ">" if byteOrder == b"MM" else "<"
        # maps TIFF type id to size (in bytes)
        # and python format char for struct
        tiffTypes = {
            1: (1, boChar + "B"),  # BYTE
            2: (1, boChar + "c"),  # ASCII
            3: (2, boChar + "H"),  # SHORT
            4: (4, boChar + "L"),  # LONG
            5: (8, boChar + "LL"),  # RATIONAL
            6: (1, boChar + "b"),  # SBYTE
            7: (1, boChar + "c"),  # UNDEFINED
            8: (2, boChar + "h"),  # SSHORT
            9: (4, boChar + "l"),  # SLONG
            10: (8, boChar + "ll"),  # SRATIONAL
            11: (4, boChar + "f"),  # FLOAT
            12: (8, boChar + "d")  # DOUBLE
        }
        ifdOffset = struct.unpack(boChar + "L", data[4:8])[0]
        try:
            countSize = 2
            input.seek(ifdOffset)
            ec = input.read(countSize)
            ifdEntryCount = struct.unpack(boChar + "H", ec)[0]
            # 2 bytes: TagId + 2 bytes: type + 4 bytes: count of values + 4
            # bytes: value offset
            ifdEntrySize = 12
            for i in range(ifdEntryCount):
                entryOffset = ifdOffset + countSize + i * ifdEntrySize
                input.seek(entryOffset)
                tag = input.read(2)
                tag = struct.unpack(boChar + "H", tag)[0]
                if (tag == 256 or tag == 257):
                    # if type indicates that value fits into 4 bytes, value
                    # offset is not an offset but value itself
                    field_type = input.read(2)
                    field_type = struct.unpack(boChar + "H", field_type)[0]
                    if field_type not in tiffTypes:
                        raise UnknownImageFormat(
                            "Unkown TIFF field type:" +
                            str(field_type))
                    typeSize = tiffTypes[field_type][0]
                    typeChar = tiffTypes[field_type][1]
                    input.seek(entryOffset + 8)
                    value = input.read(typeSize)
                    value = int(struct.unpack(typeChar, value)[0])
                    if tag == 256:
                        width = value
                    else:
                        height = value
                if width > -1 and height > -1:
                    break
        except Exception as e:
            raise UnknownImageFormat(str(e))
    elif size >= 2:
        # see http://en.wikipedia.org/wiki/ICO_(file_format)
        imgtype = 'ICO'
        input.seek(0)
        reserved = input.read(2)
        if 0 != struct.unpack("<H", reserved)[0]:
            raise UnknownImageFormat(FILE_UNKNOWN)
        ico_type = input.read(2)
        assert 1 == struct.unpack("<H", ico_type)[0]
        num = input.read(2)
        num = struct.unpack("<H", num)[0]
        if num > 1:
            import warnings
            warnings.warn("ICO File contains more than one image")
        # http://msdn.microsoft.com/en-us/library/ms997538.aspx
        w = input.read(1)
        h = input.read(1)
        width = ord(w)
        height = ord(h)
    else:
        raise UnknownImageFormat(FILE_UNKNOWN)

    return Image(path=file_path,
                 type=imgtype,
                 file_size=size,
                 width=width,
                 height=height)
sys.argv[1:]) + Returns: + int: zero for OK + """ + import logging + import optparse + import sys + + prs = optparse.OptionParser( + usage="%prog [-v|--verbose] [--json|--json-indent] []", + description="Print metadata for the given image paths " + "(without image library bindings).") + + prs.add_option('--json', + dest='json', + action='store_true') + prs.add_option('--json-indent', + dest='json_indent', + action='store_true') + + prs.add_option('-v', '--verbose', + dest='verbose', + action='store_true', ) + prs.add_option('-q', '--quiet', + dest='quiet', + action='store_true', ) + prs.add_option('-t', '--test', + dest='run_tests', + action='store_true', ) + + argv = list(argv) if argv is not None else sys.argv[1:] + (opts, args) = prs.parse_args(args=argv) + loglevel = logging.INFO + if opts.verbose: + loglevel = logging.DEBUG + elif opts.quiet: + loglevel = logging.ERROR + logging.basicConfig(level=loglevel) + log = logging.getLogger() + log.debug('argv: %r', argv) + log.debug('opts: %r', opts) + log.debug('args: %r', args) + + if opts.run_tests: + import sys + sys.argv = [sys.argv[0]] + args + import unittest + return unittest.main() + + output_func = Image.to_str_row + if opts.json_indent: + import functools + output_func = functools.partial(Image.to_str_json, indent=2) + elif opts.json: + output_func = Image.to_str_json + elif opts.verbose: + output_func = Image.to_str_row_verbose + + EX_OK = 0 + EX_NOT_OK = 2 + + if len(args) < 1: + prs.print_help() + print('') + prs.error("You must specify one or more paths to image files") + + errors = [] + for path_arg in args: + try: + img = get_image_metadata(path_arg) + print(output_func(img)) + except KeyboardInterrupt: + raise + except OSError as e: + log.error((path_arg, e)) + errors.append((path_arg, e)) + except Exception as e: + log.exception(e) + errors.append((path_arg, e)) + pass + if len(errors): + import pprint + print("ERRORS", file=sys.stderr) + print("======", file=sys.stderr) + 
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
    """Lay the images of a tensor out side by side and hand them to `show_img`.

    A 4-D tensor is split along dim 0 and the pieces are concatenated along
    dim 3; any other tensor is used as a single piece. Values are mapped with
    ``x / 2 + 0.5``, clipped to [0, 1], scaled to 0-255, moved channels-last
    and cast to uint8 before the first (only) strip is displayed under `name`.
    """
    if len(imgs.shape) == 4:
        pieces = torch.chunk(imgs, imgs.shape[0], dim=0)
    else:
        pieces = [imgs]

    # one wide strip, rescaled to the unit range
    strip = torch.cat(pieces, dim=3) / 2 + 0.5
    pixels = strip.to(torch.float32).detach().cpu().numpy()
    pixels = (np.clip(pixels, 0, 1) * 255).transpose(0, 2, 3, 1).astype(np.uint8)

    show_img(pixels[0], name=name)
def reduce_contrast(tensor, factor):
    """Scale `tensor` toward its own mean and clamp the result to [-1, 1].

    `factor` is first limited to the range 0..1: 1 leaves the distance from
    the mean unchanged, 0 collapses every value onto the mean.
    """
    # keep the blend factor inside 0..1
    factor = max(0, min(factor, 1))

    # pull every value toward the global mean, then clamp to the -1..1 range
    center = tensor.mean()
    return ((tensor - center) * factor + center).clamp(-1.0, 1.0)
"mu_tilde(x_t, x_0) DDPM paper eq. 7" + prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps + alpha_prod_t_prev = model.scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod + alpha_t = model.scheduler.alphas[timestep] + beta_t = 1 - alpha_t + alpha_bar = model.scheduler.alphas_cumprod[timestep] + return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + ( + (alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt + + +def sample_xts_from_x0(sd: StableDiffusion, sample: torch.Tensor, num_inference_steps=50): + """ + Samples from P(x_1:T|x_0) + """ + # torch.manual_seed(43256465436) + alpha_bar = sd.noise_scheduler.alphas_cumprod + sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5 + alphas = sd.noise_scheduler.alphas + betas = 1 - alphas + # variance_noise_shape = ( + # num_inference_steps, + # sd.unet.in_channels, + # sd.unet.sample_size, + # sd.unet.sample_size) + variance_noise_shape = list(sample.shape) + variance_noise_shape[0] = num_inference_steps + + timesteps = sd.noise_scheduler.timesteps.to(sd.device) + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + xts = torch.zeros(variance_noise_shape).to(sample.device, dtype=torch.float16) + for t in reversed(timesteps): + idx = t_to_idx[int(t)] + xts[idx] = sample * (alpha_bar[t] ** 0.5) + torch.randn_like(sample, dtype=torch.float16) * sqrt_one_minus_alpha_bar[t] + xts = torch.cat([xts, sample], dim=0) + + return xts + + +def encode_text(model, prompts): + text_input = model.tokenizer( + prompts, + padding="max_length", + max_length=model.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + with torch.no_grad(): + text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0] + return text_encoding + + +def forward_step(sd: StableDiffusion, model_output, timestep, sample): + next_timestep = min( + 
def get_variance(sd: StableDiffusion, timestep):  # , prev_timestep):
    """Return the variance term for `timestep` from `sd.noise_scheduler`.

    The previous timestep is `timestep` minus
    ``num_train_timesteps // num_inference_steps``; when that is negative the
    scheduler's `final_alpha_cumprod` stands in for the previous cumulative
    alpha. The result is
    ``(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)``.
    """
    scheduler = sd.noise_scheduler
    stride = scheduler.config['num_train_timesteps'] // scheduler.num_inference_steps
    prev_timestep = timestep - stride

    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep]
    else:
        alpha_prod_t_prev = scheduler.final_alpha_cumprod

    beta_ratio = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
    return beta_ratio * (1 - alpha_prod_t / alpha_prod_t_prev)
add_time_ids = add_time_ids.to(latents.device, dtype=dtype) + + batch_time_ids = torch.cat( + [add_time_ids for _ in range(bs)] + ) + return batch_time_ids + else: + return None + + +def inversion_forward_process( + sd: StableDiffusion, + sample: torch.Tensor, + conditional_embeddings: PromptEmbeds, + unconditional_embeddings: PromptEmbeds, + etas=None, + prog_bar=False, + cfg_scale=3.5, + num_inference_steps=50, eps=None +): + current_num_timesteps = len(sd.noise_scheduler.timesteps) + sd.noise_scheduler.set_timesteps(num_inference_steps, device=sd.device) + + timesteps = sd.noise_scheduler.timesteps.to(sd.device) + # variance_noise_shape = ( + # num_inference_steps, + # sd.unet.in_channels, + # sd.unet.sample_size, + # sd.unet.sample_size + # ) + variance_noise_shape = list(sample.shape) + variance_noise_shape[0] = num_inference_steps + if etas is None or (type(etas) in [int, float] and etas == 0): + eta_is_zero = True + zs = None + else: + eta_is_zero = False + if type(etas) in [int, float]: etas = [etas] * sd.noise_scheduler.num_inference_steps + xts = sample_xts_from_x0(sd, sample, num_inference_steps=num_inference_steps) + alpha_bar = sd.noise_scheduler.alphas_cumprod + zs = torch.zeros(size=variance_noise_shape, device=sd.device, dtype=torch.float16) + + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + noisy_sample = sample + op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps) + + for timestep in op: + idx = t_to_idx[int(timestep)] + # 1. 
predict noise residual + if not eta_is_zero: + noisy_sample = xts[idx][None] + + added_cond_kwargs = {} + + with torch.no_grad(): + text_embeddings = train_tools.concat_prompt_embeddings( + unconditional_embeddings, # negative embedding + conditional_embeddings, # positive embedding + 1, # batch size + ) + if sd.is_xl: + add_time_ids = get_time_ids_from_latents(sd, noisy_sample) + # add extra for cfg + add_time_ids = torch.cat( + [add_time_ids] * 2, dim=0 + ) + + added_cond_kwargs = { + "text_embeds": text_embeddings.pooled_embeds, + "time_ids": add_time_ids, + } + + # double up for cfg + latent_model_input = torch.cat( + [noisy_sample] * 2, dim=0 + ) + + noise_pred = sd.unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings.text_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + # out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=uncond_embedding) + # cond_out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=text_embeddings) + + noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond) + + if eta_is_zero: + # 2. 
compute more noisy image and set x_t -> x_t+1 + noisy_sample = forward_step(sd, noise_pred, timestep, noisy_sample) + xts = None + + else: + xtm1 = xts[idx + 1][None] + # pred of x0 + pred_original_sample = (noisy_sample - (1 - alpha_bar[timestep]) ** 0.5 * noise_pred) / alpha_bar[ + timestep] ** 0.5 + + # direction to xt + prev_timestep = timestep - sd.noise_scheduler.config[ + 'num_train_timesteps'] // sd.noise_scheduler.num_inference_steps + alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod + + variance = get_variance(sd, timestep) + pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred + + mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5) + zs[idx] = z + + # correction to avoid error accumulation + xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z + xts[idx + 1] = xtm1 + + if not zs is None: + zs[-1] = torch.zeros_like(zs[-1]) + + # restore timesteps + sd.noise_scheduler.set_timesteps(current_num_timesteps, device=sd.device) + + return noisy_sample, zs, xts + + +# +# def inversion_forward_process( +# model, +# sample, +# etas=None, +# prog_bar=False, +# prompt="", +# cfg_scale=3.5, +# num_inference_steps=50, eps=None +# ): +# if not prompt == "": +# text_embeddings = encode_text(model, prompt) +# uncond_embedding = encode_text(model, "") +# timesteps = model.scheduler.timesteps.to(model.device) +# variance_noise_shape = ( +# num_inference_steps, +# model.unet.in_channels, +# model.unet.sample_size, +# model.unet.sample_size) +# if etas is None or (type(etas) in [int, float] and etas == 0): +# eta_is_zero = True +# zs = None +# else: +# eta_is_zero = False +# if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps +# xts = sample_xts_from_x0(model, sample, num_inference_steps=num_inference_steps) +# alpha_bar = 
def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
    """Take one reverse (denoising) step x_t -> x_t-1.

    Follows formula (12) of https://arxiv.org/pdf/2010.02502.pdf with the
    variance of formula (16).

    Args:
        model: Object exposing a scheduler as ``model.scheduler`` (falling back
            to ``model.noise_scheduler``) and, when noise has to be sampled,
            ``model.device``. The scheduler must provide
            ``config.num_train_timesteps``, ``num_inference_steps``,
            ``alphas_cumprod`` and ``final_alpha_cumprod``.
        model_output: Predicted noise for `sample` at `timestep`.
        timestep: Current timestep; used to index ``alphas_cumprod``.
        sample: Current noisy sample x_t.
        eta: Noise scale; 0 gives the deterministic step.
        variance_noise: Noise to inject when ``eta > 0``. Sampled (float16, on
            ``model.device``) when omitted.

    Returns:
        The sample at the previous timestep.
    """
    # Use one scheduler for everything. The variance used to come from
    # get_variance(), which reads `noise_scheduler` with dict-style config
    # access, while this function read `scheduler` with attribute access, so
    # an object exposing only one of the two attributes failed.
    scheduler = getattr(model, "scheduler", None)
    if scheduler is None:
        scheduler = model.noise_scheduler

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    # 2. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = scheduler.alphas_cumprod[
        prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * variance ** (0.5)
    # 6. compute "direction pointing to x_t" of formula (12).
    # NOTE(review): this subtracts eta * variance, not std_dev_t ** 2; the two
    # agree only for eta in {0, 1}. Kept as-is to preserve existing results.
    pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output
    # 7. compute x_t without "random noise" of formula (12)
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    # 8. Add noise if eta > 0
    if eta > 0:
        if variance_noise is None:
            variance_noise = torch.randn(model_output.shape, device=model.device, dtype=torch.float16)
        prev_sample = prev_sample + std_dev_t * variance_noise

    return prev_sample
class MLPProjModelClipFace(torch.nn.Module):
    """Project a single identity/CLIP embedding into a short sequence of context tokens.

    The input is layer-normalised, widened to twice its size, passed through a
    GELU and then projected to ``num_tokens * cross_attention_dim`` features,
    which are finally folded into ``num_tokens`` separate tokens.

    Args:
        cross_attention_dim: feature size of each produced token.
        id_embeddings_dim: feature size of the incoming embedding (last dim of ``x``).
        num_tokens: number of tokens produced per input embedding.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens
        self.norm = torch.nn.LayerNorm(id_embeddings_dim)

        # Build the two linear layers explicitly so the output layer can be
        # re-initialised below without indexing back into the Sequential.
        widened_dim = id_embeddings_dim * 2
        expand = torch.nn.Linear(id_embeddings_dim, widened_dim)
        project = torch.nn.Linear(widened_dim, cross_attention_dim * num_tokens)
        self.proj = torch.nn.Sequential(expand, torch.nn.GELU(), project)

        # Start the output layer close to zero: tiny uniform weights, zero bias.
        torch.nn.init.uniform_(project.weight, a=-0.01, b=0.01)
        torch.nn.init.zeros_(project.bias)

    def forward(self, x):
        """Return tokens reshaped to ``(-1, num_tokens, cross_attention_dim)``."""
        tokens = self.proj(self.norm(x))
        return tokens.reshape(-1, self.num_tokens, self.cross_attention_dim)
hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if is_active: + # since we are removing tokens, we need to adjust the sequence length + sequence_length = sequence_length - self.num_tokens + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + # will be none if disabled + if not is_active: + ip_hidden_states = None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for 
attn.scale when we move to Torch 2.1 + try: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + except Exception as e: + print(e) + raise e + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # will be none if disabled + if ip_hidden_states is not None: + # apply scaler + if self.train_scaler: + weight = self.ip_scaler + # reshape to (1, self.num_tokens, 1) + weight = weight.view(1, -1, 1) + ip_hidden_states = ip_hidden_states * weight + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + scale = self.scale + hidden_states = hidden_states + scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + # this ensures that the ip_scaler is not changed when we load the model + # def _apply(self, fn): + # if hasattr(self, "ip_scaler"): + # # Overriding the _apply method to prevent the special_parameter 
class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
    """Flux attention processor with an additional IP-Adapter attention branch.

    The regular Flux attention is computed first (joint text+image attention
    when ``encoder_hidden_states`` is given, plain attention otherwise). When
    the owning adapter is active, the image-prompt tokens stored on the adapter
    (``last_conditional`` / ``last_unconditional``) are projected through
    ``to_k_ip`` / ``to_v_ip``, attended to with the same queries, and the
    result is added to the regular attention output scaled by ``self.scale``.

    Args:
        hidden_size: output feature size of the ``to_k_ip`` / ``to_v_ip`` projections.
        cross_attention_dim: feature size of the image-prompt tokens; falls back
            to ``hidden_size`` when falsy.
        scale: multiplier applied to the IP attention output before it is added.
        num_tokens: number of image-prompt tokens (size of the per-token scaler).
        adapter: object exposing ``is_active``, ``last_conditional`` and
            ``last_unconditional``. Only a weak reference is kept. ``None`` (or
            an adapter that has since been garbage collected) makes the
            processor behave as if the adapter were inactive.
        train_scaler: when true, a learnable ``ip_scaler`` parameter multiplies
            the image-prompt tokens before projection.
        full_token_scaler: when true the scaler has one entry per token,
            otherwise a single shared entry.
    """

    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False,
                 full_token_scaler=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        # weakref.ref(None) raises TypeError, so the default adapter=None must be guarded.
        self.adapter_ref: Optional[weakref.ref] = weakref.ref(adapter) if adapter is not None else None
        self.train_scaler = train_scaler
        if train_scaler:
            scaler_size = num_tokens if full_token_scaler else 1
            self.ip_scaler = torch.nn.Parameter(torch.ones([scaler_size], dtype=torch.float32) * 0.999)
            self.ip_scaler.requires_grad_(True)

    def _get_adapter(self):
        """Return the live adapter, or None when none was given or it has been collected."""
        if self.adapter_ref is None:
            return None
        return self.adapter_ref()

    @staticmethod
    def _gather_ip_hidden_states(adapter, batch_size):
        """Fetch the stored image-prompt tokens and align their batch with ``batch_size``.

        Raises:
            ValueError: if no conditional tokens are stored, if the batch is doubled
                but no unconditional tokens are stored, or if the resulting batch
                still does not match ``batch_size``.
        """
        ip_hidden_states = adapter.last_conditional
        if ip_hidden_states is None:
            raise ValueError("Conditional is None but should not be")
        # add unconditional to front if it exists
        if ip_hidden_states.shape[0] * 2 == batch_size:
            if adapter.last_unconditional is None:
                raise ValueError("Unconditional is None but should not be")
            ip_hidden_states = torch.cat([adapter.last_unconditional, ip_hidden_states], dim=0)
        if ip_hidden_states.shape[0] != batch_size:
            # A later .view(batch_size, ...) would otherwise silently mix samples.
            raise ValueError(
                f"IP adapter batch size {ip_hidden_states.shape[0]} does not match attention batch size {batch_size}"
            )
        return ip_hidden_states

    def __call__(
            self,
            attn,
            hidden_states: torch.FloatTensor,
            encoder_hidden_states: torch.FloatTensor = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Run Flux attention plus the optional IP-Adapter branch.

        Args:
            attn: attention module providing the projections (``to_q``, ``to_k``,
                ``to_v``, ``add_*_proj``, ``to_out``, ``to_add_out``), the optional
                norms and ``heads``.
            hidden_states: image-stream tokens, indexed as (batch, tokens, features).
            encoder_hidden_states: text-stream tokens, or None for blocks that do
                not use a separate context stream.
            attention_mask: accepted for interface compatibility; not used here.
            image_rotary_emb: rotary embedding applied to query and key when given.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` when ``encoder_hidden_states``
            was provided, otherwise just ``hidden_states``.
        """
        adapter = self._get_adapter()
        is_active = adapter is not None and bool(adapter.is_active)
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention: context tokens go first, image tokens after them
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # begin ip adapter
        if is_active:
            ip_hidden_states = self._gather_ip_hidden_states(adapter, batch_size)

            # apply scaler
            if self.train_scaler:
                # reshape to (1, num_tokens or 1, 1) so it broadcasts over batch and features
                weight = self.ip_scaler.view(1, -1, 1)
                ip_hidden_states = ip_hidden_states * weight

            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states
        # end ip adapter

        if encoder_hidden_states is not None:
            # split the joint sequence back into its context and image parts
            context_len = encoder_hidden_states.shape[1]
            encoder_hidden_states, hidden_states = (
                hidden_states[:, :context_len],
                hidden_states[:, context_len:],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
num_vectors=sd.unet.config['cross_attention_dim'], + reducer_channels=self.config.safe_reducer_channels, + channels=self.config.safe_channels, + downscale_factor=8 + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'convnext': + try: + self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.clip_image_processor = ConvNextImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.image_encoder = ConvNextForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'convnextv2': + try: + self.clip_image_processor = AutoImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.clip_image_processor = ConvNextImageProcessor( + size=512, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + self.image_encoder = ConvNextV2ForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + else: + raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}") + + if not self.config.train_image_encoder: + # compile it + print('Compiling image encoder') + #torch.compile(self.image_encoder, fullgraph=True) + + self.input_size = self.image_encoder.config.image_size + + if self.config.quad_image: # 4x4 image + # self.clip_image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + preprocessor_input_size = 
self.image_encoder.config.image_size * 2 + + # update the preprocessor so images come in at the right size + if 'height' in self.clip_image_processor.size: + self.clip_image_processor.size['height'] = preprocessor_input_size + self.clip_image_processor.size['width'] = preprocessor_input_size + elif hasattr(self.clip_image_processor, 'crop_size'): + self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size + self.clip_image_processor.crop_size['height'] = preprocessor_input_size + self.clip_image_processor.crop_size['width'] = preprocessor_input_size + + if self.config.image_encoder_arch == 'clip+': + # self.clip_image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + preprocessor_input_size = self.image_encoder.config.image_size * 4 + + # update the preprocessor so images come in at the right size + self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size + self.clip_image_processor.crop_size['height'] = preprocessor_input_size + self.clip_image_processor.crop_size['width'] = preprocessor_input_size + + self.preprocessor = CLIPImagePreProcessor( + input_size=preprocessor_input_size, + clip_input_size=self.image_encoder.config.image_size, + ) + if not self.config.image_encoder_arch == 'safe': + if 'height' in self.clip_image_processor.size: + self.input_size = self.clip_image_processor.size['height'] + elif hasattr(self.clip_image_processor, 'crop_size'): + self.input_size = self.clip_image_processor.crop_size['height'] + elif 'shortest_edge' in self.clip_image_processor.size.keys(): + self.input_size = self.clip_image_processor.size['shortest_edge'] + else: + raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}") + self.current_scale = 1.0 + self.is_active = True + is_pixart = sd.is_pixart + is_flux = sd.is_flux + if adapter_config.type == 'ip': + # ip-adapter + image_proj_model = ImageProjModel( + cross_attention_dim=sd.unet.config['cross_attention_dim'], + 
clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.config.num_tokens, # usually 4 + ) + elif adapter_config.type == 'ip_clip_face': + cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim'] + image_proj_model = MLPProjModelClipFace( + cross_attention_dim=cross_attn_dim, + id_embeddings_dim=self.image_encoder.config.projection_dim, + num_tokens=self.config.num_tokens, # usually 4 + ) + elif adapter_config.type == 'ip+': + heads = 12 if not sd.is_xl else 20 + if is_flux: + dim = 1280 + else: + dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280 + embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith( + 'convnext') else \ + self.image_encoder.config.hidden_sizes[-1] + + image_encoder_state_dict = self.image_encoder.state_dict() + # max_seq_len = CLIP tokens + CLS token + max_seq_len = 257 + if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict: + # clip + max_seq_len = int( + image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + + if is_pixart: + heads = 20 + dim = 1280 + output_dim = 4096 + elif is_flux: + heads = 20 + dim = 1280 + output_dim = 3072 + else: + output_dim = sd.unet.config['cross_attention_dim'] + + if self.config.image_encoder_arch.startswith('convnext'): + in_tokens = 16 * 16 + embedding_dim = self.image_encoder.config.hidden_sizes[-1] + + # ip-adapter-plus + image_proj_model = Resampler( + dim=dim, + depth=4, + dim_head=64, + heads=heads, + num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len, + embedding_dim=embedding_dim, + max_seq_len=max_seq_len, + output_dim=output_dim, + ff_mult=4 + ) + elif adapter_config.type == 'ipz': + dim = sd.unet.config['cross_attention_dim'] + if hasattr(self.image_encoder.config, 'hidden_sizes'): + embedding_dim = self.image_encoder.config.hidden_sizes[-1] + else: + embedding_dim = 
self.image_encoder.config.target_hidden_size + + image_encoder_state_dict = self.image_encoder.state_dict() + # max_seq_len = CLIP tokens + CLS token + in_tokens = 257 + if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict: + # clip + in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + + if self.config.image_encoder_arch.startswith('convnext'): + in_tokens = 16 * 16 + embedding_dim = self.image_encoder.config.hidden_sizes[-1] + + is_conv_next = self.config.image_encoder_arch.startswith('convnext') + + out_tokens = self.config.num_tokens if self.config.num_tokens > 0 else in_tokens + # ip-adapter-plus + image_proj_model = ZipperResampler( + in_size=embedding_dim, + in_tokens=in_tokens, + out_size=dim, + out_tokens=out_tokens, + hidden_size=embedding_dim, + hidden_tokens=in_tokens, + # num_blocks=1 if not is_conv_next else 2, + num_blocks=1 if not is_conv_next else 2, + is_conv_input=is_conv_next + ) + elif adapter_config.type == 'ilora': + # we apply the clip encodings to the LoRA + image_proj_model = None + else: + raise ValueError(f"unknown adapter type: {adapter_config.type}") + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + attn_processor_keys = [] + if is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + elif is_flux: + transformer: FluxTransformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn") + + # single transformer blocks do not have cross attn, but we will do them anyway + for i, module in transformer.single_transformer_blocks.named_children(): + attn_processor_keys.append(f"single_transformer_blocks.{i}.attn") + else: + 
attn_processor_keys = list(sd.unet.attn_processors.keys()) + + attn_processor_names = [] + + blocks = [] + transformer_blocks = [] + for name in attn_processor_keys: + name_split = name.split(".") + block_name = f"{name_split[0]}.{name_split[1]}" + transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1 + if transformer_idx >= 0: + transformer_name = ".".join(name_split[:2]) + transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2]) + if transformer_name not in transformer_blocks: + transformer_blocks.append(transformer_name) + + + if block_name not in blocks: + blocks.append(block_name) + if is_flux: + cross_attention_dim = None + else: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \ + sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer") or name.startswith("single_transformer"): + if is_flux: + hidden_size = 3072 + else: + hidden_size = sd.unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None and not is_flux: + attn_procs[name] = AttnProcessor2_0() + else: + layer_name = name.split(".processor")[0] + + # if quantized, we need to scale the weights + if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux: + # is quantized + + k_weight = torch.randn(hidden_size, hidden_size) * 0.01 + v_weight = torch.randn(hidden_size, hidden_size) * 0.01 + k_weight = 
k_weight.to(self.sd_ref().torch_dtype) + v_weight = v_weight.to(self.sd_ref().torch_dtype) + else: + k_weight = unet_sd[layer_name + ".to_k.weight"] + v_weight = unet_sd[layer_name + ".to_v.weight"] + + weights = { + "to_k_ip.weight": k_weight, + "to_v_ip.weight": v_weight + } + + if is_flux: + attn_procs[name] = CustomIPFluxAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.config.num_tokens, + adapter=self, + train_scaler=self.config.train_scaler or self.config.merge_scaler, + full_token_scaler=False + ) + else: + attn_procs[name] = CustomIPAttentionProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.config.num_tokens, + adapter=self, + train_scaler=self.config.train_scaler or self.config.merge_scaler, + # full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler + full_token_scaler=False + ) + if self.sd_ref().is_pixart or self.sd_ref().is_flux: + # pixart is much more sensitive + weights = { + "to_k_ip.weight": weights["to_k_ip.weight"] * 0.01, + "to_v_ip.weight": weights["to_v_ip.weight"] * 0.01, + } + + attn_procs[name].load_state_dict(weights, strict=False) + attn_processor_names.append(name) + print(f"Attn Processors") + print(attn_processor_names) + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn2.processor for i in + range(len(transformer.transformer_blocks)) + ]) + elif self.sd_ref().is_flux: + # we have to set them ourselves + transformer: FluxTransformer2DModel = sd.unet + for i, module in 
transformer.transformer_blocks.named_children(): + module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"] + + # do single blocks too even though they dont have cross attn + for i, module in transformer.single_transformer_blocks.named_children(): + module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"] + + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn.processor for i in + range(len(transformer.transformer_blocks)) + ] + [ + transformer.single_transformer_blocks[i].attn.processor for i in + range(len(transformer.single_transformer_blocks)) + ] + ) + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + sd.adapter = self + self.unet_ref: weakref.ref = weakref.ref(sd.unet) + self.image_proj_model = image_proj_model + # load the weights if we have some + if self.config.name_or_path: + loaded_state_dict = load_ip_adapter_model( + self.config.name_or_path, + device='cpu', + dtype=sd.torch_dtype + ) + self.load_state_dict(loaded_state_dict) + + self.set_scale(1.0) + + if self.config.train_image_encoder: + self.image_encoder.train() + self.image_encoder.requires_grad_(True) + + # premake a unconditional + zerod = torch.zeros(1, 3, self.input_size, self.input_size, device=self.device, dtype=torch.float16) + self.unconditional = self.clip_image_processor( + images=zerod, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.image_encoder.to(*args, **kwargs) + self.image_proj_model.to(*args, **kwargs) + self.adapter_modules.to(*args, **kwargs) + if self.preprocessor is not None: + self.preprocessor.to(*args, **kwargs) + return self + + # def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]): + # self.image_proj_model.load_state_dict(state_dict["image_proj"]) + # ip_layers = 
torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + # ip_layers.load_state_dict(state_dict["ip_adapter"]) + # if self.config.train_image_encoder and 'image_encoder' in state_dict: + # self.image_encoder.load_state_dict(state_dict["image_encoder"]) + # if self.preprocessor is not None and 'preprocessor' in state_dict: + # self.preprocessor.load_state_dict(state_dict["preprocessor"]) + + # def load_state_dict(self, state_dict: Union[OrderedDict, dict]): + # self.load_ip_adapter(state_dict) + + def state_dict(self) -> OrderedDict: + state_dict = OrderedDict() + if self.config.train_only_image_encoder: + return self.image_encoder.state_dict() + if self.config.train_scaler: + state_dict["ip_scale"] = self.adapter_modules.state_dict() + # remove items that are not scalers + for key in list(state_dict["ip_scale"].keys()): + if not key.endswith("ip_scaler"): + del state_dict["ip_scale"][key] + return state_dict + + state_dict["image_proj"] = self.image_proj_model.state_dict() + state_dict["ip_adapter"] = self.adapter_modules.state_dict() + # handle merge scaler training + if self.config.merge_scaler: + for key in list(state_dict["ip_adapter"].keys()): + if key.endswith("ip_scaler"): + # merge in the scaler so we dont have to save it and it will be compatible with other ip adapters + scale = state_dict["ip_adapter"][key].clone() + + key_start = key.split(".")[-2] + # reshape to (1, 1) + scale = scale.view(1, 1) + del state_dict["ip_adapter"][key] + # find the to_k_ip and to_v_ip keys + for key2 in list(state_dict["ip_adapter"].keys()): + if key2.endswith(f"{key_start}.to_k_ip.weight"): + state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale + if key2.endswith(f"{key_start}.to_v_ip.weight"): + state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale + + if self.config.train_image_encoder: + state_dict["image_encoder"] = self.image_encoder.state_dict() + if self.preprocessor is not None: + state_dict["preprocessor"] 
= self.preprocessor.state_dict() + return state_dict + + def get_scale(self): + return self.current_scale + + def set_scale(self, scale): + self.current_scale = scale + if not self.sd_ref().is_pixart and not self.sd_ref().is_flux: + for attn_processor in self.sd_ref().unet.attn_processors.values(): + if isinstance(attn_processor, CustomIPAttentionProcessor): + attn_processor.scale = scale + + # @torch.no_grad() + # def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], + # drop=False) -> torch.Tensor: + # # todo: add support for sdxl + # if isinstance(pil_image, Image.Image): + # pil_image = [pil_image] + # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + # clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + # if drop: + # clip_image = clip_image * 0 + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + # return clip_image_embeds + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.image_encoder.to(*args, **kwargs) + self.image_proj_model.to(*args, **kwargs) + self.adapter_modules.to(*args, **kwargs) + if self.preprocessor is not None: + self.preprocessor.to(*args, **kwargs) + return self + + def parse_clip_image_embeds_from_cache( + self, + image_embeds_list: List[dict], # has ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states'] + quad_count=4, + ): + with torch.no_grad(): + device = self.sd_ref().unet.device + clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0) + + if self.config.quad_image: + # get the outputs of the quat + chunks = clip_image_embeds.chunk(quad_count, dim=0) + chunk_sum = torch.zeros_like(chunks[0]) + for chunk in chunks: + chunk_sum = chunk_sum + chunk + # get the mean of them + + clip_image_embeds = chunk_sum / quad_count + + clip_image_embeds = clip_image_embeds.to(device, 
dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + return clip_image_embeds + + def get_empty_clip_image(self, batch_size: int) -> torch.Tensor: + with torch.no_grad(): + tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device) + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + # tensors_0_1 = tensors_0_1 * 0 + mean = torch.tensor(self.clip_image_processor.image_mean).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + std = torch.tensor(self.clip_image_processor.image_std).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + return clip_image.detach() + + def get_clip_image_embeds_from_tensors( + self, + tensors_0_1: torch.Tensor, + drop=False, + is_training=False, + has_been_preprocessed=False, + quad_count=4, + cfg_embed_strength=None, # perform CFG on embeds with unconditional as negative + ) -> torch.Tensor: + if self.sd_ref().unet.device != self.device: + self.to(self.sd_ref().unet.device) + if self.sd_ref().unet.device != self.image_encoder.device: + self.to(self.sd_ref().unet.device) + if not self.config.train: + is_training = False + uncond_clip = None + with torch.no_grad(): + # on training the clip image is created in the dataloader + if not has_been_preprocessed: + # tensors should be 0-1 + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + # training tensors are 0 - 1 + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + + # if images are out of this range throw error + if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: + raise ValueError("image tensor values must be between 0 and 1. 
Got min: {}, max: {}".format( + tensors_0_1.min(), tensors_0_1.max() + )) + # unconditional + if drop: + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 + clip_image = self.clip_image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + else: + if drop: + # scale the noise down + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 + mean = torch.tensor(self.clip_image_processor.image_mean).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + std = torch.tensor(self.clip_image_processor.image_std).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + tensors_0_1 = torch.clip((255. 
* tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + + else: + clip_image = tensors_0_1 + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + + if self.config.quad_image: + # split the 4x4 grid and stack on batch + ci1, ci2 = clip_image.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + to_cat = [] + for i, ci in enumerate([ci1, ci2, ci3, ci4]): + if i < quad_count: + to_cat.append(ci) + else: + break + + clip_image = torch.cat(to_cat, dim=0).detach() + + # if drop: + # clip_image = clip_image * 0 + with torch.set_grad_enabled(is_training): + if is_training and self.config.train_image_encoder: + self.image_encoder.train() + clip_image = clip_image.requires_grad_(True) + if self.preprocessor is not None: + clip_image = self.preprocessor(clip_image) + clip_output = self.image_encoder( + clip_image, + output_hidden_states=True + ) + else: + self.image_encoder.eval() + if self.preprocessor is not None: + clip_image = self.preprocessor(clip_image) + clip_output = self.image_encoder( + clip_image, output_hidden_states=True + ) + + if self.config.clip_layer == 'penultimate_hidden_states': + # they skip last layer for ip+ + # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26 + clip_image_embeds = clip_output.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + clip_image_embeds = clip_output.hidden_states[-1] + else: + clip_image_embeds = clip_output.image_embeds + + if self.config.adapter_type == "clip_face": + l2_norm = torch.norm(clip_image_embeds, p=2) + clip_image_embeds = clip_image_embeds / l2_norm + + if self.config.image_encoder_arch.startswith('convnext'): + # flatten the width height layers to make the token space + clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1) + # 
rearrange to (batch, tokens, size) + clip_image_embeds = clip_image_embeds.permute(0, 2, 1) + + # apply unconditional if doing cfg on embeds + with torch.no_grad(): + if cfg_embed_strength is not None: + uncond_clip = self.get_empty_clip_image(tensors_0_1.shape[0]).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + if self.config.quad_image: + # split the 4x4 grid and stack on batch + ci1, ci2 = uncond_clip.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + to_cat = [] + for i, ci in enumerate([ci1, ci2, ci3, ci4]): + if i < quad_count: + to_cat.append(ci) + else: + break + + uncond_clip = torch.cat(to_cat, dim=0).detach() + uncond_clip_output = self.image_encoder( + uncond_clip, output_hidden_states=True + ) + + if self.config.clip_layer == 'penultimate_hidden_states': + uncond_clip_output_embeds = uncond_clip_output.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + uncond_clip_output_embeds = uncond_clip_output.hidden_states[-1] + else: + uncond_clip_output_embeds = uncond_clip_output.image_embeds + if self.config.adapter_type == "clip_face": + l2_norm = torch.norm(uncond_clip_output_embeds, p=2) + uncond_clip_output_embeds = uncond_clip_output_embeds / l2_norm + + uncond_clip_output_embeds = uncond_clip_output_embeds.detach() + + + # apply inverse cfg + clip_image_embeds = inverse_classifier_guidance( + clip_image_embeds, + uncond_clip_output_embeds, + cfg_embed_strength + ) + + + if self.config.quad_image: + # get the outputs of the quat + chunks = clip_image_embeds.chunk(quad_count, dim=0) + if self.config.train_image_encoder and is_training: + # perform a loss across all chunks this will teach the vision encoder to + # identify similarities in our pairs of images and ignore things that do not make them similar + num_losses = 0 + total_loss = None + for chunk in chunks: + for chunk2 in chunks: + if chunk is not chunk2: + loss = F.mse_loss(chunk, chunk2) + if total_loss is None: + total_loss = 
loss + else: + total_loss = total_loss + loss + num_losses += 1 + if total_loss is not None: + total_loss = total_loss / num_losses + total_loss = total_loss * 1e-2 + if self.additional_loss is not None: + total_loss = total_loss + self.additional_loss + self.additional_loss = total_loss + + chunk_sum = torch.zeros_like(chunks[0]) + for chunk in chunks: + chunk_sum = chunk_sum + chunk + # get the mean of them + + clip_image_embeds = chunk_sum / quad_count + + if not is_training or not self.config.train_image_encoder: + clip_image_embeds = clip_image_embeds.detach() + + return clip_image_embeds + + # use drop for prompt dropout, or negatives + def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor, is_unconditional=False) -> PromptEmbeds: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + if self.sd_ref().is_flux: + # do not attach to text embeds for flux, we will save and grab them as it messes + # with the RoPE to have them in the same tensor + if is_unconditional: + self.last_unconditional = image_prompt_embeds + else: + self.last_conditional = image_prompt_embeds + else: + embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1) + return embeddings + + def train(self: T, mode: bool = True) -> T: + if self.config.train_image_encoder: + self.image_encoder.train(mode) + if not self.config.train_only_image_encoder: + for attn_processor in self.adapter_modules: + attn_processor.train(mode) + if self.image_proj_model is not None: + self.image_proj_model.train(mode) + return super().train(mode) + + def get_parameter_groups(self, adapter_lr): + param_groups = [] + # when training just scaler, we do not train anything else + if not self.config.train_scaler: + param_groups.append({ + "params": list(self.get_non_scaler_parameters()), + "lr": adapter_lr, + }) + if self.config.train_scaler or 
self.config.merge_scaler: + scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr + param_groups.append({ + "params": list(self.get_scaler_parameters()), + "lr": scaler_lr, + }) + return param_groups + + def get_scaler_parameters(self): + # only get the scalera from the adapter modules + for attn_processor in self.adapter_modules: + # only get the scaler + # check if it has ip_scaler attribute + if hasattr(attn_processor, "ip_scaler"): + scaler_param = attn_processor.ip_scaler + yield scaler_param + + def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]: + if self.config.train_only_image_encoder: + if self.config.train_only_image_encoder_positional_embedding: + yield from self.image_encoder.vision_model.embeddings.position_embedding.parameters(recurse) + else: + yield from self.image_encoder.parameters(recurse) + return + if self.config.train_scaler: + # no params + return + + for attn_processor in self.adapter_modules: + if self.config.train_scaler or self.config.merge_scaler: + # todo remove scaler + if hasattr(attn_processor, "to_k_ip"): + # yield the linear layer + yield from attn_processor.to_k_ip.parameters(recurse) + if hasattr(attn_processor, "to_v_ip"): + # yield the linear layer + yield from attn_processor.to_v_ip.parameters(recurse) + else: + yield from attn_processor.parameters(recurse) + yield from self.image_proj_model.parameters(recurse) + if self.config.train_image_encoder: + yield from self.image_encoder.parameters(recurse) + if self.preprocessor is not None: + yield from self.preprocessor.parameters(recurse) + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + yield from self.get_non_scaler_parameters(recurse) + if self.config.train_scaler or self.config.merge_scaler: + yield from self.get_scaler_parameters() + + def merge_in_weights(self, state_dict: Mapping[str, Any]): + # merge in img_proj weights + current_img_proj_state_dict = self.image_proj_model.state_dict() + for 
key, value in state_dict["image_proj"].items(): + if key in current_img_proj_state_dict: + current_shape = current_img_proj_state_dict[key].shape + new_shape = value.shape + if current_shape != new_shape: + try: + # merge in what we can and leave the other values as they are + if len(current_shape) == 1: + current_img_proj_state_dict[key][:new_shape[0]] = value + elif len(current_shape) == 2: + current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value + elif len(current_shape) == 3: + current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value + elif len(current_shape) == 4: + current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], + :new_shape[3]] = value + else: + raise ValueError(f"unknown shape: {current_shape}") + except RuntimeError as e: + print(e) + print( + f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way") + + if len(current_shape) == 1: + current_img_proj_state_dict[key][:current_shape[0]] = value[:current_shape[0]] + elif len(current_shape) == 2: + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[ + :current_shape[0], + :current_shape[1]] + elif len(current_shape) == 3: + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], + :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] + elif len(current_shape) == 4: + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] + else: + raise ValueError(f"unknown shape: {current_shape}") + print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + else: + current_img_proj_state_dict[key] = value + self.image_proj_model.load_state_dict(current_img_proj_state_dict) + + # merge in ip adapter weights + current_ip_adapter_state_dict = self.adapter_modules.state_dict() + for key, 
value in state_dict["ip_adapter"].items(): + if key in current_ip_adapter_state_dict: + current_shape = current_ip_adapter_state_dict[key].shape + new_shape = value.shape + if current_shape != new_shape: + try: + # merge in what we can and leave the other values as they are + if len(current_shape) == 1: + current_ip_adapter_state_dict[key][:new_shape[0]] = value + elif len(current_shape) == 2: + current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value + elif len(current_shape) == 3: + current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value + elif len(current_shape) == 4: + current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], + :new_shape[3]] = value + else: + raise ValueError(f"unknown shape: {current_shape}") + print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + except RuntimeError as e: + print(e) + print( + f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way") + + if (len(current_shape) == 1): + current_ip_adapter_state_dict[key][:current_shape[0]] = value[:current_shape[0]] + elif (len(current_shape) == 2): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[ + :current_shape[ + 0], + :current_shape[ + 1]] + elif (len(current_shape) == 3): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], + :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] + elif (len(current_shape) == 4): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] + else: + raise ValueError(f"unknown shape: {current_shape}") + print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + + else: + current_ip_adapter_state_dict[key] = value + self.adapter_modules.load_state_dict(current_ip_adapter_state_dict) + + 
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + strict = False + if self.config.train_scaler and 'ip_scale' in state_dict: + self.adapter_modules.load_state_dict(state_dict["ip_scale"], strict=False) + if 'ip_adapter' in state_dict: + try: + self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict) + self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict) + except Exception as e: + print(e) + print("could not load ip adapter weights, trying to merge in weights") + self.merge_in_weights(state_dict) + if self.config.train_image_encoder and 'image_encoder' in state_dict: + self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict) + if self.preprocessor is not None and 'preprocessor' in state_dict: + self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict) + + if self.config.train_only_image_encoder and 'ip_adapter' not in state_dict: + # we are loading pure clip weights. + self.image_encoder.load_state_dict(state_dict, strict=strict) + + def enable_gradient_checkpointing(self): + if hasattr(self.image_encoder, "enable_gradient_checkpointing"): + self.image_encoder.enable_gradient_checkpointing() + elif hasattr(self.image_encoder, 'gradient_checkpointing'): + self.image_encoder.gradient_checkpointing = True diff --git a/ai-toolkit/toolkit/job.py b/ai-toolkit/toolkit/job.py new file mode 100644 index 0000000000000000000000000000000000000000..dc274fb798efb5780056bd2134eb2940a608a98c --- /dev/null +++ b/ai-toolkit/toolkit/job.py @@ -0,0 +1,44 @@ +from typing import Union, OrderedDict + +from toolkit.config import get_config + + +def get_job( + config_path: Union[str, dict, OrderedDict], + name=None +): + config = get_config(config_path, name) + if not config['job']: + raise ValueError('config file is invalid. 
Missing "job" key') + + job = config['job'] + if job == 'extract': + from jobs import ExtractJob + return ExtractJob(config) + if job == 'train': + from jobs import TrainJob + return TrainJob(config) + if job == 'mod': + from jobs import ModJob + return ModJob(config) + if job == 'generate': + from jobs import GenerateJob + return GenerateJob(config) + if job == 'extension': + from jobs import ExtensionJob + return ExtensionJob(config) + + # elif job == 'train': + # from jobs import TrainJob + # return TrainJob(config) + else: + raise ValueError(f'Unknown job type {job}') + + +def run_job( + config: Union[str, dict, OrderedDict], + name=None +): + job = get_job(config, name) + job.run() + job.cleanup() diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_refiner.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_refiner.json new file mode 100644 index 0000000000000000000000000000000000000000..4c7525d8804da9ec92b7f87bc01741d4372ac83d --- /dev/null +++ b/ai-toolkit/toolkit/keymaps/stable_diffusion_refiner.json @@ -0,0 +1,3498 @@ +{ + "ldm_diffusers_keymap": { + "conditioner.embedders.0.model.ln_final.bias": "te1_text_model.final_layer_norm.bias", + "conditioner.embedders.0.model.ln_final.weight": "te1_text_model.final_layer_norm.weight", + "conditioner.embedders.0.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.0.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.ln_1.weight": 
"te1_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias", + 
"conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight", + 
"conditioner.embedders.0.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.12.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias", + 
"conditioner.embedders.0.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.attn.out_proj.bias": 
"te1_text_model.encoder.layers.16.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight", + 
"conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias", + 
"conditioner.embedders.0.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_fc.bias": 
"te1_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight", + 
"conditioner.embedders.0.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_fc.weight": 
"te1_text_model.encoder.layers.22.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias", + 
"conditioner.embedders.0.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_proj.bias": 
"te1_text_model.encoder.layers.25.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight", + 
"conditioner.embedders.0.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.28.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias", + 
"conditioner.embedders.0.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.attn.out_proj.bias": 
"te1_text_model.encoder.layers.31.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight", + 
"conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.attn.out_proj.weight": 
"te1_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias", + 
"conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.ln_1.bias": 
"te1_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + 
"first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": 
"vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": 
"vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + 
"first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + 
"first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + 
"first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + 
"first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + 
"first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + 
"first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": 
"vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": 
"vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": 
"vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": 
"vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": 
"unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": 
"unet_down_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + 
"model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": 
"unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": 
"unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": 
"unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_v.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm2.weight", 
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": 
"unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", 
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.0.proj.bias": 
"unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_q.weight": 
"unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm1.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + 
"model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": 
"unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight", + "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias", + "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight", + "model.diffusion_model.label_emb.0.2.bias": 
"unet_add_embedding.linear_2.bias", + "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": 
"unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": 
"unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.bias": 
"unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.bias": 
"unet_mid_block.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight", + 
"model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": 
"unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": 
"unet_up_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "unet_up_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "unet_up_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "unet_up_blocks.3.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "unet_up_blocks.3.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "unet_up_blocks.3.resnets.2.conv2.bias", + 
"model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "unet_up_blocks.3.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "unet_up_blocks.3.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "unet_up_blocks.3.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + 
"model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": 
"unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": 
"unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": 
"unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm2.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": 
"unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + 
"model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_q.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_out.0.bias": 
"unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": 
"unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": 
"unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": 
"unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.bias", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_k.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm3.bias", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.0.proj.weight": 
"unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + 
"model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.6.1.norm.bias": "unet_up_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.6.1.norm.weight": "unet_up_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.6.1.proj_in.bias": "unet_up_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.6.1.proj_in.weight": "unet_up_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.6.1.proj_out.bias": "unet_up_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.6.1.proj_out.weight": "unet_up_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": 
"unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_q.weight": 
"unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_out.0.bias": 
"unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm3.weight", + 
"model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.1.norm.bias": "unet_up_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.7.1.norm.weight": "unet_up_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.7.1.proj_in.bias": "unet_up_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.7.1.proj_in.weight": "unet_up_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.7.1.proj_out.bias": "unet_up_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.7.1.proj_out.weight": "unet_up_blocks.2.attentions.1.proj_out.weight", + 
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": 
"unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": 
"unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm3.bias": 
"unet_up_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm2.bias", + 
"model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.1.norm.bias": "unet_up_blocks.2.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.8.1.norm.weight": "unet_up_blocks.2.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.8.1.proj_in.bias": 
"unet_up_blocks.2.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.8.1.proj_in.weight": "unet_up_blocks.2.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.8.1.proj_out.bias": "unet_up_blocks.2.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.8.1.proj_out.weight": "unet_up_blocks.2.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm1.weight": 
"unet_up_blocks.2.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": 
"unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": 
"unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": 
"unet_up_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "unet_up_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "unet_up_blocks.3.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "unet_up_blocks.3.resnets.0.conv_shortcut.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + 
"conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + 
"te1_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.17.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.23.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.25.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.27.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.29.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.31.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": 
{ + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight", + "2560:, :" + ] + } + } +} \ No newline at end of file diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_refiner_unmatched.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_refiner_unmatched.json new file mode 100644 index 0000000000000000000000000000000000000000..cb5aba0a543c8ad50094abb3f99e266336908aaa --- /dev/null +++ b/ai-toolkit/toolkit/keymaps/stable_diffusion_refiner_unmatched.json @@ -0,0 +1,27 @@ +{ + "ldm": { + "conditioner.embedders.0.model.logit_scale": { + "shape": [], + "min": 4.60546875, + "max": 4.60546875 + }, + "conditioner.embedders.0.model.text_projection": { + "shape": [ + 1280, + 1280 + ], + "min": -0.15966796875, + "max": 0.230712890625 + } + }, + "diffusers": { + "te1_text_projection.weight": { + "shape": [ + 1280, + 1280 + ], + "min": -0.15966796875, + "max": 0.230712890625 + } + } +} \ No newline at end of file diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_sd1.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_sd1.json new file mode 100644 index 0000000000000000000000000000000000000000..8f04f753ac6656fdc2a2d44d8d07ebc7db184689 --- /dev/null +++ b/ai-toolkit/toolkit/keymaps/stable_diffusion_sd1.json @@ -0,0 +1,1234 @@ +{ + 
"ldm_diffusers_keymap": { + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "te_text_model.embeddings.position_embedding.weight", + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "te_text_model.embeddings.token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te_text_model.encoder.layers.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te_text_model.encoder.layers.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te_text_model.encoder.layers.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te_text_model.encoder.layers.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te_text_model.encoder.layers.0.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te_text_model.encoder.layers.0.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te_text_model.encoder.layers.0.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te_text_model.encoder.layers.0.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te_text_model.encoder.layers.0.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te_text_model.encoder.layers.0.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te_text_model.encoder.layers.0.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te_text_model.encoder.layers.0.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": 
"te_text_model.encoder.layers.0.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te_text_model.encoder.layers.0.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te_text_model.encoder.layers.0.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te_text_model.encoder.layers.0.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te_text_model.encoder.layers.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te_text_model.encoder.layers.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te_text_model.encoder.layers.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te_text_model.encoder.layers.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te_text_model.encoder.layers.1.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te_text_model.encoder.layers.1.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te_text_model.encoder.layers.1.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te_text_model.encoder.layers.1.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te_text_model.encoder.layers.1.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te_text_model.encoder.layers.1.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te_text_model.encoder.layers.1.self_attn.out_proj.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te_text_model.encoder.layers.1.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te_text_model.encoder.layers.1.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te_text_model.encoder.layers.1.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te_text_model.encoder.layers.1.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te_text_model.encoder.layers.1.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te_text_model.encoder.layers.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te_text_model.encoder.layers.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te_text_model.encoder.layers.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te_text_model.encoder.layers.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te_text_model.encoder.layers.10.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te_text_model.encoder.layers.10.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te_text_model.encoder.layers.10.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te_text_model.encoder.layers.10.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te_text_model.encoder.layers.10.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": 
"te_text_model.encoder.layers.10.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te_text_model.encoder.layers.10.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te_text_model.encoder.layers.10.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te_text_model.encoder.layers.10.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te_text_model.encoder.layers.10.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te_text_model.encoder.layers.10.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te_text_model.encoder.layers.10.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te_text_model.encoder.layers.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te_text_model.encoder.layers.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te_text_model.encoder.layers.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te_text_model.encoder.layers.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te_text_model.encoder.layers.11.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te_text_model.encoder.layers.11.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te_text_model.encoder.layers.11.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te_text_model.encoder.layers.11.mlp.fc2.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te_text_model.encoder.layers.11.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te_text_model.encoder.layers.11.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te_text_model.encoder.layers.11.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te_text_model.encoder.layers.11.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te_text_model.encoder.layers.11.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te_text_model.encoder.layers.11.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te_text_model.encoder.layers.11.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te_text_model.encoder.layers.11.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te_text_model.encoder.layers.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te_text_model.encoder.layers.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te_text_model.encoder.layers.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te_text_model.encoder.layers.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te_text_model.encoder.layers.2.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te_text_model.encoder.layers.2.mlp.fc1.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te_text_model.encoder.layers.2.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te_text_model.encoder.layers.2.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te_text_model.encoder.layers.2.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te_text_model.encoder.layers.2.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te_text_model.encoder.layers.2.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te_text_model.encoder.layers.2.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te_text_model.encoder.layers.2.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te_text_model.encoder.layers.2.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te_text_model.encoder.layers.2.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te_text_model.encoder.layers.2.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te_text_model.encoder.layers.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te_text_model.encoder.layers.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te_text_model.encoder.layers.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te_text_model.encoder.layers.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": 
"te_text_model.encoder.layers.3.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te_text_model.encoder.layers.3.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te_text_model.encoder.layers.3.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te_text_model.encoder.layers.3.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te_text_model.encoder.layers.3.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te_text_model.encoder.layers.3.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te_text_model.encoder.layers.3.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te_text_model.encoder.layers.3.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te_text_model.encoder.layers.3.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te_text_model.encoder.layers.3.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te_text_model.encoder.layers.3.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te_text_model.encoder.layers.3.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te_text_model.encoder.layers.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te_text_model.encoder.layers.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te_text_model.encoder.layers.4.layer_norm2.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te_text_model.encoder.layers.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te_text_model.encoder.layers.4.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te_text_model.encoder.layers.4.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te_text_model.encoder.layers.4.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te_text_model.encoder.layers.4.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te_text_model.encoder.layers.4.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te_text_model.encoder.layers.4.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te_text_model.encoder.layers.4.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te_text_model.encoder.layers.4.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te_text_model.encoder.layers.4.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te_text_model.encoder.layers.4.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te_text_model.encoder.layers.4.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te_text_model.encoder.layers.4.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te_text_model.encoder.layers.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": 
"te_text_model.encoder.layers.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te_text_model.encoder.layers.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te_text_model.encoder.layers.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te_text_model.encoder.layers.5.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te_text_model.encoder.layers.5.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te_text_model.encoder.layers.5.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te_text_model.encoder.layers.5.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te_text_model.encoder.layers.5.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te_text_model.encoder.layers.5.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te_text_model.encoder.layers.5.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te_text_model.encoder.layers.5.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te_text_model.encoder.layers.5.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te_text_model.encoder.layers.5.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te_text_model.encoder.layers.5.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te_text_model.encoder.layers.5.self_attn.v_proj.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te_text_model.encoder.layers.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te_text_model.encoder.layers.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te_text_model.encoder.layers.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te_text_model.encoder.layers.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te_text_model.encoder.layers.6.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te_text_model.encoder.layers.6.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te_text_model.encoder.layers.6.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te_text_model.encoder.layers.6.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te_text_model.encoder.layers.6.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te_text_model.encoder.layers.6.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te_text_model.encoder.layers.6.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te_text_model.encoder.layers.6.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te_text_model.encoder.layers.6.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te_text_model.encoder.layers.6.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": 
"te_text_model.encoder.layers.6.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te_text_model.encoder.layers.6.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te_text_model.encoder.layers.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te_text_model.encoder.layers.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te_text_model.encoder.layers.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te_text_model.encoder.layers.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te_text_model.encoder.layers.7.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te_text_model.encoder.layers.7.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te_text_model.encoder.layers.7.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te_text_model.encoder.layers.7.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te_text_model.encoder.layers.7.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te_text_model.encoder.layers.7.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te_text_model.encoder.layers.7.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te_text_model.encoder.layers.7.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te_text_model.encoder.layers.7.self_attn.q_proj.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te_text_model.encoder.layers.7.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te_text_model.encoder.layers.7.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te_text_model.encoder.layers.7.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te_text_model.encoder.layers.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te_text_model.encoder.layers.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te_text_model.encoder.layers.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te_text_model.encoder.layers.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te_text_model.encoder.layers.8.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te_text_model.encoder.layers.8.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te_text_model.encoder.layers.8.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te_text_model.encoder.layers.8.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te_text_model.encoder.layers.8.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te_text_model.encoder.layers.8.self_attn.k_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te_text_model.encoder.layers.8.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": 
"te_text_model.encoder.layers.8.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te_text_model.encoder.layers.8.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te_text_model.encoder.layers.8.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te_text_model.encoder.layers.8.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te_text_model.encoder.layers.8.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te_text_model.encoder.layers.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te_text_model.encoder.layers.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te_text_model.encoder.layers.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te_text_model.encoder.layers.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te_text_model.encoder.layers.9.mlp.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te_text_model.encoder.layers.9.mlp.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te_text_model.encoder.layers.9.mlp.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te_text_model.encoder.layers.9.mlp.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te_text_model.encoder.layers.9.self_attn.k_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te_text_model.encoder.layers.9.self_attn.k_proj.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te_text_model.encoder.layers.9.self_attn.out_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te_text_model.encoder.layers.9.self_attn.out_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te_text_model.encoder.layers.9.self_attn.q_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te_text_model.encoder.layers.9.self_attn.q_proj.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te_text_model.encoder.layers.9.self_attn.v_proj.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te_text_model.encoder.layers.9.self_attn.v_proj.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "te_text_model.final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "te_text_model.final_layer_norm.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": 
"vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + 
"first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + 
"first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": 
"vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": 
"vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": 
"vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": 
"vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": 
"vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": 
"vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": 
"vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": 
"vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": 
"vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + 
"first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + 
"model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.1.1.norm.bias": "unet_down_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.1.1.norm.weight": "unet_down_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.1.1.proj_in.bias": "unet_down_blocks.0.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.1.1.proj_in.weight": "unet_down_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.1.1.proj_out.bias": "unet_down_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.1.1.proj_out.weight": "unet_down_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": 
"unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": 
"unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight", + 
"model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.2.1.norm.bias": "unet_down_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.2.1.norm.weight": "unet_down_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.2.1.proj_in.bias": "unet_down_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.2.1.proj_in.weight": "unet_down_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.2.1.proj_out.bias": 
"unet_down_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.2.1.proj_out.weight": "unet_down_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": 
"unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + 
"model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": 
"unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + 
"model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": 
"unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + 
"model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + 
"model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + 
"model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + 
"model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "unet_up_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "unet_up_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "unet_up_blocks.3.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "unet_up_blocks.3.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.1.norm.bias": "unet_up_blocks.3.attentions.1.norm.bias", + 
"model.diffusion_model.output_blocks.10.1.norm.weight": "unet_up_blocks.3.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.10.1.proj_in.bias": "unet_up_blocks.3.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.10.1.proj_in.weight": "unet_up_blocks.3.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.10.1.proj_out.bias": "unet_up_blocks.3.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.10.1.proj_out.weight": "unet_up_blocks.3.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q.weight", + 
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias", + 
"model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "unet_up_blocks.3.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "unet_up_blocks.3.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "unet_up_blocks.3.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "unet_up_blocks.3.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.11.1.norm.bias": "unet_up_blocks.3.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.11.1.norm.weight": "unet_up_blocks.3.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.11.1.proj_in.bias": "unet_up_blocks.3.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.11.1.proj_in.weight": "unet_up_blocks.3.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.11.1.proj_out.bias": "unet_up_blocks.3.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.11.1.proj_out.weight": "unet_up_blocks.3.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + 
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": 
"unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + 
"model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": 
"unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": 
"unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + 
"model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + 
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + 
"model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + 
"model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.6.1.norm.bias": "unet_up_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.6.1.norm.weight": "unet_up_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.6.1.proj_in.bias": "unet_up_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.6.1.proj_in.weight": "unet_up_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.6.1.proj_out.bias": "unet_up_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.6.1.proj_out.weight": "unet_up_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": 
"unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": 
"unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.1.norm.bias": "unet_up_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.7.1.norm.weight": "unet_up_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.7.1.proj_in.bias": "unet_up_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.7.1.proj_in.weight": "unet_up_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.7.1.proj_out.bias": "unet_up_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.7.1.proj_out.weight": "unet_up_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": 
"unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": 
"unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.1.norm.bias": "unet_up_blocks.2.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.8.1.norm.weight": "unet_up_blocks.2.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.8.1.proj_in.bias": "unet_up_blocks.2.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.8.1.proj_in.weight": "unet_up_blocks.2.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.8.1.proj_out.bias": "unet_up_blocks.2.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.8.1.proj_out.weight": 
"unet_up_blocks.2.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": 
"unet_up_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "unet_up_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "unet_up_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "unet_up_blocks.3.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "unet_up_blocks.3.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.9.1.norm.bias": "unet_up_blocks.3.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.9.1.norm.weight": "unet_up_blocks.3.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.9.1.proj_in.bias": "unet_up_blocks.3.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.9.1.proj_in.weight": "unet_up_blocks.3.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.9.1.proj_out.bias": "unet_up_blocks.3.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.9.1.proj_out.weight": "unet_up_blocks.3.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.bias", + 
"model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": {}, + "diffusers_ldm_operator_map": {} +} \ No newline at end of file diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_sd2.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_sd2.json new file mode 100644 index 
0000000000000000000000000000000000000000..868facaf5b6119f5d3a82d369fe509b82da1f551 --- /dev/null +++ b/ai-toolkit/toolkit/keymaps/stable_diffusion_sd2.json @@ -0,0 +1,2424 @@ +{ + "ldm_diffusers_keymap": { + "cond_stage_model.model.ln_final.bias": "te_text_model.final_layer_norm.bias", + "cond_stage_model.model.ln_final.weight": "te_text_model.final_layer_norm.weight", + "cond_stage_model.model.positional_embedding": "te_text_model.embeddings.position_embedding.weight", + "cond_stage_model.model.token_embedding.weight": "te_text_model.embeddings.token_embedding.weight", + "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.bias": "te_text_model.encoder.layers.0.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight": "te_text_model.encoder.layers.0.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.0.ln_1.bias": "te_text_model.encoder.layers.0.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.0.ln_1.weight": "te_text_model.encoder.layers.0.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.0.ln_2.bias": "te_text_model.encoder.layers.0.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.0.ln_2.weight": "te_text_model.encoder.layers.0.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.bias": "te_text_model.encoder.layers.0.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.weight": "te_text_model.encoder.layers.0.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.bias": "te_text_model.encoder.layers.0.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.weight": "te_text_model.encoder.layers.0.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.1.attn.out_proj.bias": "te_text_model.encoder.layers.1.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.1.attn.out_proj.weight": 
"te_text_model.encoder.layers.1.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.1.ln_1.bias": "te_text_model.encoder.layers.1.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.1.ln_1.weight": "te_text_model.encoder.layers.1.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.1.ln_2.bias": "te_text_model.encoder.layers.1.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.1.ln_2.weight": "te_text_model.encoder.layers.1.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.bias": "te_text_model.encoder.layers.1.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.weight": "te_text_model.encoder.layers.1.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.bias": "te_text_model.encoder.layers.1.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.weight": "te_text_model.encoder.layers.1.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.10.attn.out_proj.bias": "te_text_model.encoder.layers.10.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.10.attn.out_proj.weight": "te_text_model.encoder.layers.10.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.10.ln_1.bias": "te_text_model.encoder.layers.10.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.10.ln_1.weight": "te_text_model.encoder.layers.10.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.10.ln_2.bias": "te_text_model.encoder.layers.10.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.10.ln_2.weight": "te_text_model.encoder.layers.10.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.bias": "te_text_model.encoder.layers.10.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.weight": "te_text_model.encoder.layers.10.mlp.fc1.weight", + 
"cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.bias": "te_text_model.encoder.layers.10.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.weight": "te_text_model.encoder.layers.10.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.11.attn.out_proj.bias": "te_text_model.encoder.layers.11.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.11.attn.out_proj.weight": "te_text_model.encoder.layers.11.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.11.ln_1.bias": "te_text_model.encoder.layers.11.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.11.ln_1.weight": "te_text_model.encoder.layers.11.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.11.ln_2.bias": "te_text_model.encoder.layers.11.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.11.ln_2.weight": "te_text_model.encoder.layers.11.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.bias": "te_text_model.encoder.layers.11.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.weight": "te_text_model.encoder.layers.11.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.bias": "te_text_model.encoder.layers.11.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.weight": "te_text_model.encoder.layers.11.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.12.attn.out_proj.bias": "te_text_model.encoder.layers.12.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.12.attn.out_proj.weight": "te_text_model.encoder.layers.12.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.12.ln_1.bias": "te_text_model.encoder.layers.12.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.12.ln_1.weight": "te_text_model.encoder.layers.12.layer_norm1.weight", + 
"cond_stage_model.model.transformer.resblocks.12.ln_2.bias": "te_text_model.encoder.layers.12.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.12.ln_2.weight": "te_text_model.encoder.layers.12.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.bias": "te_text_model.encoder.layers.12.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.weight": "te_text_model.encoder.layers.12.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.bias": "te_text_model.encoder.layers.12.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.weight": "te_text_model.encoder.layers.12.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.13.attn.out_proj.bias": "te_text_model.encoder.layers.13.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.13.attn.out_proj.weight": "te_text_model.encoder.layers.13.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.13.ln_1.bias": "te_text_model.encoder.layers.13.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.13.ln_1.weight": "te_text_model.encoder.layers.13.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.13.ln_2.bias": "te_text_model.encoder.layers.13.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.13.ln_2.weight": "te_text_model.encoder.layers.13.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.bias": "te_text_model.encoder.layers.13.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.weight": "te_text_model.encoder.layers.13.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.bias": "te_text_model.encoder.layers.13.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.weight": "te_text_model.encoder.layers.13.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.14.attn.out_proj.bias": 
"te_text_model.encoder.layers.14.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.14.attn.out_proj.weight": "te_text_model.encoder.layers.14.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.14.ln_1.bias": "te_text_model.encoder.layers.14.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.14.ln_1.weight": "te_text_model.encoder.layers.14.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.14.ln_2.bias": "te_text_model.encoder.layers.14.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.14.ln_2.weight": "te_text_model.encoder.layers.14.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.bias": "te_text_model.encoder.layers.14.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.weight": "te_text_model.encoder.layers.14.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.bias": "te_text_model.encoder.layers.14.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.weight": "te_text_model.encoder.layers.14.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.15.attn.out_proj.bias": "te_text_model.encoder.layers.15.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.15.attn.out_proj.weight": "te_text_model.encoder.layers.15.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.15.ln_1.bias": "te_text_model.encoder.layers.15.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.15.ln_1.weight": "te_text_model.encoder.layers.15.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.15.ln_2.bias": "te_text_model.encoder.layers.15.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.15.ln_2.weight": "te_text_model.encoder.layers.15.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.bias": "te_text_model.encoder.layers.15.mlp.fc1.bias", + 
"cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.weight": "te_text_model.encoder.layers.15.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.bias": "te_text_model.encoder.layers.15.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.weight": "te_text_model.encoder.layers.15.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.16.attn.out_proj.bias": "te_text_model.encoder.layers.16.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.16.attn.out_proj.weight": "te_text_model.encoder.layers.16.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.16.ln_1.bias": "te_text_model.encoder.layers.16.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.16.ln_1.weight": "te_text_model.encoder.layers.16.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.16.ln_2.bias": "te_text_model.encoder.layers.16.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.16.ln_2.weight": "te_text_model.encoder.layers.16.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.bias": "te_text_model.encoder.layers.16.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.weight": "te_text_model.encoder.layers.16.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.bias": "te_text_model.encoder.layers.16.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.weight": "te_text_model.encoder.layers.16.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.17.attn.out_proj.bias": "te_text_model.encoder.layers.17.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.17.attn.out_proj.weight": "te_text_model.encoder.layers.17.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.17.ln_1.bias": "te_text_model.encoder.layers.17.layer_norm1.bias", + 
"cond_stage_model.model.transformer.resblocks.17.ln_1.weight": "te_text_model.encoder.layers.17.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.17.ln_2.bias": "te_text_model.encoder.layers.17.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.17.ln_2.weight": "te_text_model.encoder.layers.17.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.bias": "te_text_model.encoder.layers.17.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.weight": "te_text_model.encoder.layers.17.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.bias": "te_text_model.encoder.layers.17.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.weight": "te_text_model.encoder.layers.17.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.18.attn.out_proj.bias": "te_text_model.encoder.layers.18.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.18.attn.out_proj.weight": "te_text_model.encoder.layers.18.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.18.ln_1.bias": "te_text_model.encoder.layers.18.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.18.ln_1.weight": "te_text_model.encoder.layers.18.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.18.ln_2.bias": "te_text_model.encoder.layers.18.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.18.ln_2.weight": "te_text_model.encoder.layers.18.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.bias": "te_text_model.encoder.layers.18.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.weight": "te_text_model.encoder.layers.18.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.bias": "te_text_model.encoder.layers.18.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.weight": 
"te_text_model.encoder.layers.18.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.19.attn.out_proj.bias": "te_text_model.encoder.layers.19.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.19.attn.out_proj.weight": "te_text_model.encoder.layers.19.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.19.ln_1.bias": "te_text_model.encoder.layers.19.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.19.ln_1.weight": "te_text_model.encoder.layers.19.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.19.ln_2.bias": "te_text_model.encoder.layers.19.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.19.ln_2.weight": "te_text_model.encoder.layers.19.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.bias": "te_text_model.encoder.layers.19.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.weight": "te_text_model.encoder.layers.19.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.bias": "te_text_model.encoder.layers.19.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.weight": "te_text_model.encoder.layers.19.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.2.attn.out_proj.bias": "te_text_model.encoder.layers.2.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.2.attn.out_proj.weight": "te_text_model.encoder.layers.2.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.2.ln_1.bias": "te_text_model.encoder.layers.2.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.2.ln_1.weight": "te_text_model.encoder.layers.2.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.2.ln_2.bias": "te_text_model.encoder.layers.2.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.2.ln_2.weight": "te_text_model.encoder.layers.2.layer_norm2.weight", + 
"cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.bias": "te_text_model.encoder.layers.2.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.weight": "te_text_model.encoder.layers.2.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.bias": "te_text_model.encoder.layers.2.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.weight": "te_text_model.encoder.layers.2.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.20.attn.out_proj.bias": "te_text_model.encoder.layers.20.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.20.attn.out_proj.weight": "te_text_model.encoder.layers.20.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.20.ln_1.bias": "te_text_model.encoder.layers.20.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.20.ln_1.weight": "te_text_model.encoder.layers.20.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.20.ln_2.bias": "te_text_model.encoder.layers.20.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.20.ln_2.weight": "te_text_model.encoder.layers.20.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.bias": "te_text_model.encoder.layers.20.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.weight": "te_text_model.encoder.layers.20.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.bias": "te_text_model.encoder.layers.20.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.weight": "te_text_model.encoder.layers.20.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.21.attn.out_proj.bias": "te_text_model.encoder.layers.21.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.21.attn.out_proj.weight": "te_text_model.encoder.layers.21.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.21.ln_1.bias": 
"te_text_model.encoder.layers.21.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.21.ln_1.weight": "te_text_model.encoder.layers.21.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.21.ln_2.bias": "te_text_model.encoder.layers.21.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.21.ln_2.weight": "te_text_model.encoder.layers.21.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.bias": "te_text_model.encoder.layers.21.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.weight": "te_text_model.encoder.layers.21.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.bias": "te_text_model.encoder.layers.21.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.weight": "te_text_model.encoder.layers.21.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.bias": "te_text_model.encoder.layers.22.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight": "te_text_model.encoder.layers.22.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.22.ln_1.bias": "te_text_model.encoder.layers.22.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.22.ln_1.weight": "te_text_model.encoder.layers.22.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.22.ln_2.bias": "te_text_model.encoder.layers.22.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.22.ln_2.weight": "te_text_model.encoder.layers.22.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.bias": "te_text_model.encoder.layers.22.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.weight": "te_text_model.encoder.layers.22.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.bias": "te_text_model.encoder.layers.22.mlp.fc2.bias", + 
"cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.weight": "te_text_model.encoder.layers.22.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.3.attn.out_proj.bias": "te_text_model.encoder.layers.3.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.3.attn.out_proj.weight": "te_text_model.encoder.layers.3.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.3.ln_1.bias": "te_text_model.encoder.layers.3.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.3.ln_1.weight": "te_text_model.encoder.layers.3.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.3.ln_2.bias": "te_text_model.encoder.layers.3.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.3.ln_2.weight": "te_text_model.encoder.layers.3.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.bias": "te_text_model.encoder.layers.3.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.weight": "te_text_model.encoder.layers.3.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.bias": "te_text_model.encoder.layers.3.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.weight": "te_text_model.encoder.layers.3.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.4.attn.out_proj.bias": "te_text_model.encoder.layers.4.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.4.attn.out_proj.weight": "te_text_model.encoder.layers.4.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.4.ln_1.bias": "te_text_model.encoder.layers.4.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.4.ln_1.weight": "te_text_model.encoder.layers.4.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.4.ln_2.bias": "te_text_model.encoder.layers.4.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.4.ln_2.weight": 
"te_text_model.encoder.layers.4.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.bias": "te_text_model.encoder.layers.4.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.weight": "te_text_model.encoder.layers.4.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.bias": "te_text_model.encoder.layers.4.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.weight": "te_text_model.encoder.layers.4.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.5.attn.out_proj.bias": "te_text_model.encoder.layers.5.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.5.attn.out_proj.weight": "te_text_model.encoder.layers.5.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.5.ln_1.bias": "te_text_model.encoder.layers.5.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.5.ln_1.weight": "te_text_model.encoder.layers.5.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.5.ln_2.bias": "te_text_model.encoder.layers.5.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.5.ln_2.weight": "te_text_model.encoder.layers.5.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.bias": "te_text_model.encoder.layers.5.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.weight": "te_text_model.encoder.layers.5.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.bias": "te_text_model.encoder.layers.5.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.weight": "te_text_model.encoder.layers.5.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.6.attn.out_proj.bias": "te_text_model.encoder.layers.6.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.6.attn.out_proj.weight": "te_text_model.encoder.layers.6.self_attn.out_proj.weight", + 
"cond_stage_model.model.transformer.resblocks.6.ln_1.bias": "te_text_model.encoder.layers.6.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.6.ln_1.weight": "te_text_model.encoder.layers.6.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.6.ln_2.bias": "te_text_model.encoder.layers.6.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.6.ln_2.weight": "te_text_model.encoder.layers.6.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.bias": "te_text_model.encoder.layers.6.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.weight": "te_text_model.encoder.layers.6.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.bias": "te_text_model.encoder.layers.6.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.weight": "te_text_model.encoder.layers.6.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.7.attn.out_proj.bias": "te_text_model.encoder.layers.7.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.7.attn.out_proj.weight": "te_text_model.encoder.layers.7.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.7.ln_1.bias": "te_text_model.encoder.layers.7.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.7.ln_1.weight": "te_text_model.encoder.layers.7.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.7.ln_2.bias": "te_text_model.encoder.layers.7.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.7.ln_2.weight": "te_text_model.encoder.layers.7.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.bias": "te_text_model.encoder.layers.7.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.weight": "te_text_model.encoder.layers.7.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.bias": "te_text_model.encoder.layers.7.mlp.fc2.bias", + 
"cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.weight": "te_text_model.encoder.layers.7.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.8.attn.out_proj.bias": "te_text_model.encoder.layers.8.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.8.attn.out_proj.weight": "te_text_model.encoder.layers.8.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.8.ln_1.bias": "te_text_model.encoder.layers.8.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.8.ln_1.weight": "te_text_model.encoder.layers.8.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.8.ln_2.bias": "te_text_model.encoder.layers.8.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.8.ln_2.weight": "te_text_model.encoder.layers.8.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.bias": "te_text_model.encoder.layers.8.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.weight": "te_text_model.encoder.layers.8.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.bias": "te_text_model.encoder.layers.8.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.weight": "te_text_model.encoder.layers.8.mlp.fc2.weight", + "cond_stage_model.model.transformer.resblocks.9.attn.out_proj.bias": "te_text_model.encoder.layers.9.self_attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.9.attn.out_proj.weight": "te_text_model.encoder.layers.9.self_attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.9.ln_1.bias": "te_text_model.encoder.layers.9.layer_norm1.bias", + "cond_stage_model.model.transformer.resblocks.9.ln_1.weight": "te_text_model.encoder.layers.9.layer_norm1.weight", + "cond_stage_model.model.transformer.resblocks.9.ln_2.bias": "te_text_model.encoder.layers.9.layer_norm2.bias", + "cond_stage_model.model.transformer.resblocks.9.ln_2.weight": 
"te_text_model.encoder.layers.9.layer_norm2.weight", + "cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.bias": "te_text_model.encoder.layers.9.mlp.fc1.bias", + "cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.weight": "te_text_model.encoder.layers.9.mlp.fc1.weight", + "cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.bias": "te_text_model.encoder.layers.9.mlp.fc2.bias", + "cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.weight": "te_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + 
"first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": 
"vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + 
"first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": 
"vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": 
"vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": 
"vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": 
"vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + 
"first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": 
"vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": 
"vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + 
"first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": 
"vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + 
"first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + 
"model.diffusion_model.input_blocks.1.1.norm.bias": "unet_down_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.1.1.norm.weight": "unet_down_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.1.1.proj_in.bias": "unet_down_blocks.0.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.1.1.proj_in.weight": "unet_down_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.1.1.proj_out.bias": "unet_down_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.1.1.proj_out.weight": "unet_down_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": 
"unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight", + 
"model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight", + 
"model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.2.1.norm.bias": "unet_down_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.2.1.norm.weight": "unet_down_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.2.1.proj_in.bias": "unet_down_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.2.1.proj_in.weight": "unet_down_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.2.1.proj_out.bias": "unet_down_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.2.1.proj_out.weight": "unet_down_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": 
"unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + 
"model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": 
"unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": 
"unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + 
"model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", 
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": 
"unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + 
"model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": 
"unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + 
"model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + 
"model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "unet_up_blocks.3.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "unet_up_blocks.3.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "unet_up_blocks.3.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "unet_up_blocks.3.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.1.norm.bias": "unet_up_blocks.3.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.10.1.norm.weight": "unet_up_blocks.3.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.10.1.proj_in.bias": "unet_up_blocks.3.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.10.1.proj_in.weight": "unet_up_blocks.3.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.10.1.proj_out.bias": "unet_up_blocks.3.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.10.1.proj_out.weight": "unet_up_blocks.3.attentions.1.proj_out.weight", + 
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": 
"unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight", + 
"model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "unet_up_blocks.3.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "unet_up_blocks.3.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "unet_up_blocks.3.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "unet_up_blocks.3.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.11.1.norm.bias": "unet_up_blocks.3.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.11.1.norm.weight": "unet_up_blocks.3.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.11.1.proj_in.bias": "unet_up_blocks.3.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.11.1.proj_in.weight": "unet_up_blocks.3.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.11.1.proj_out.bias": "unet_up_blocks.3.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.11.1.proj_out.weight": "unet_up_blocks.3.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": 
"unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": 
"unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + 
"model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": 
"unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": 
"unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + 
"model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": 
"unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + 
"model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.6.1.norm.bias": "unet_up_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.6.1.norm.weight": "unet_up_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.6.1.proj_in.bias": "unet_up_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.6.1.proj_in.weight": "unet_up_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.6.1.proj_out.bias": "unet_up_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.6.1.proj_out.weight": "unet_up_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": 
"unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": 
"unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.1.norm.bias": "unet_up_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.7.1.norm.weight": "unet_up_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.7.1.proj_in.bias": "unet_up_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.7.1.proj_in.weight": "unet_up_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.7.1.proj_out.bias": 
"unet_up_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.7.1.proj_out.weight": "unet_up_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": 
"unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + 
"model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.1.norm.bias": "unet_up_blocks.2.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.8.1.norm.weight": "unet_up_blocks.2.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.8.1.proj_in.bias": "unet_up_blocks.2.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.8.1.proj_in.weight": "unet_up_blocks.2.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.8.1.proj_out.bias": "unet_up_blocks.2.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.8.1.proj_out.weight": "unet_up_blocks.2.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "unet_up_blocks.3.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "unet_up_blocks.3.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "unet_up_blocks.3.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "unet_up_blocks.3.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.9.1.norm.bias": "unet_up_blocks.3.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.9.1.norm.weight": "unet_up_blocks.3.attentions.0.norm.weight", + 
"model.diffusion_model.output_blocks.9.1.proj_in.bias": "unet_up_blocks.3.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.9.1.proj_in.weight": "unet_up_blocks.3.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.9.1.proj_out.bias": "unet_up_blocks.3.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.9.1.proj_out.weight": "unet_up_blocks.3.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": 
"unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 
1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te_text_model.encoder.layers.0.self_attn.k_proj.bias", + "te_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te_text_model.encoder.layers.1.self_attn.k_proj.bias", + "te_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.1.self_attn.q_proj.weight", + "te_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + 
"cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te_text_model.encoder.layers.11.self_attn.k_proj.weight", + "te_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.12.self_attn.q_proj.bias", + "te_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te_text_model.encoder.layers.12.self_attn.k_proj.weight", + "te_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.13.self_attn.q_proj.bias", + "te_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + 
"te_text_model.encoder.layers.13.self_attn.q_proj.weight", + "te_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te_text_model.encoder.layers.15.self_attn.k_proj.bias", + "te_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te_text_model.encoder.layers.16.self_attn.k_proj.bias", + "te_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.16.self_attn.q_proj.weight", + "te_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.17.self_attn.q_proj.bias", + 
"te_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te_text_model.encoder.layers.18.self_attn.k_proj.weight", + "te_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te_text_model.encoder.layers.19.self_attn.k_proj.weight", + "te_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.2.self_attn.q_proj.bias", + "te_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te_text_model.encoder.layers.2.self_attn.k_proj.weight", + 
"te_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te_text_model.encoder.layers.21.self_attn.k_proj.bias", + "te_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te_text_model.encoder.layers.22.self_attn.k_proj.bias", + "te_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.22.self_attn.q_proj.weight", + "te_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + 
"cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.4.self_attn.q_proj.bias", + "te_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te_text_model.encoder.layers.6.self_attn.k_proj.bias", + "te_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + 
"te_text_model.encoder.layers.7.self_attn.q_proj.bias", + "te_text_model.encoder.layers.7.self_attn.k_proj.bias", + "te_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.7.self_attn.q_proj.weight", + "te_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.8.self_attn.q_proj.weight", + "te_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te_text_model.encoder.layers.9.self_attn.k_proj.weight", + "te_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + "te_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + 
"cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.0.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias", + "1024:2048, :" + ] + }, + 
"te_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + 
"cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1024, :" + ] + }, + 
"te_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + 
"cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight", + "2048:, :" + ] + }, + 
"te_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + 
"cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight", + "1024:2048, :" + ] + }, + 
"te_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + 
"cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1024, :" + ] + }, + 
"te_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight", 
+ "0:1024, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + 
"cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.k_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias", + "2048:, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.k_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te_text_model.encoder.layers.9.self_attn.v_proj.weight": { + "slice": [ + "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight", + "2048:, :" + ] + } + } +} \ No newline at end of file diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_sd2_unmatched.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_sd2_unmatched.json new file mode 100644 index 0000000000000000000000000000000000000000..3814d87e7d37f7a1bc565132baf07a269a30422c --- /dev/null +++ b/ai-toolkit/toolkit/keymaps/stable_diffusion_sd2_unmatched.json @@ -0,0 +1,200 @@ +{ + "ldm": { + "alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 0.00466156005859375, + "max": 0.9990234375 + }, + "alphas_cumprod_prev": { + "shape": [ + 1000 + 
], + "min": 0.0047149658203125, + "max": 1.0 + }, + "betas": { + "shape": [ + 1000 + ], + "min": 0.0008502006530761719, + "max": 0.01200103759765625 + }, + "cond_stage_model.model.logit_scale": { + "shape": [], + "min": 4.60546875, + "max": 4.60546875 + }, + "cond_stage_model.model.text_projection": { + "shape": [ + 1024, + 1024 + ], + "min": -0.109130859375, + "max": 0.09271240234375 + }, + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias": { + "shape": [ + 3072 + ], + "min": -2.525390625, + "max": 2.591796875 + }, + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight": { + "shape": [ + 3072, + 1024 + ], + "min": -0.12261962890625, + "max": 0.1258544921875 + }, + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias": { + "shape": [ + 1024 + ], + "min": -0.422607421875, + "max": 1.17578125 + }, + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight": { + "shape": [ + 1024, + 1024 + ], + "min": -0.0738525390625, + "max": 0.08673095703125 + }, + "cond_stage_model.model.transformer.resblocks.23.ln_1.bias": { + "shape": [ + 1024 + ], + "min": -3.392578125, + "max": 0.90625 + }, + "cond_stage_model.model.transformer.resblocks.23.ln_1.weight": { + "shape": [ + 1024 + ], + "min": 0.379638671875, + "max": 2.02734375 + }, + "cond_stage_model.model.transformer.resblocks.23.ln_2.bias": { + "shape": [ + 1024 + ], + "min": -0.833984375, + "max": 2.525390625 + }, + "cond_stage_model.model.transformer.resblocks.23.ln_2.weight": { + "shape": [ + 1024 + ], + "min": 1.17578125, + "max": 2.037109375 + }, + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias": { + "shape": [ + 4096 + ], + "min": -1.619140625, + "max": 0.5595703125 + }, + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight": { + "shape": [ + 4096, + 1024 + ], + "min": -0.08953857421875, + "max": 0.13232421875 + }, + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias": { + "shape": [ + 1024 + ], + "min": 
-1.8662109375, + "max": 0.74658203125 + }, + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight": { + "shape": [ + 1024, + 4096 + ], + "min": -0.12939453125, + "max": 0.1009521484375 + }, + "log_one_minus_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": -7.0703125, + "max": -0.004669189453125 + }, + "model_ema.decay": { + "shape": [], + "min": 1.0, + "max": 1.0 + }, + "model_ema.num_updates": { + "shape": [], + "min": 219996, + "max": 219996 + }, + "posterior_log_variance_clipped": { + "shape": [ + 1000 + ], + "min": -46.0625, + "max": -4.421875 + }, + "posterior_mean_coef1": { + "shape": [ + 1000 + ], + "min": 0.000827789306640625, + "max": 1.0 + }, + "posterior_mean_coef2": { + "shape": [ + 1000 + ], + "min": 0.0, + "max": 0.99560546875 + }, + "posterior_variance": { + "shape": [ + 1000 + ], + "min": 0.0, + "max": 0.01200103759765625 + }, + "sqrt_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 0.0682373046875, + "max": 0.99951171875 + }, + "sqrt_one_minus_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 0.0291595458984375, + "max": 0.99755859375 + }, + "sqrt_recip_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 1.0, + "max": 14.6484375 + }, + "sqrt_recipm1_alphas_cumprod": { + "shape": [ + 1000 + ], + "min": 0.0291595458984375, + "max": 14.6171875 + } + }, + "diffusers": {} +} \ No newline at end of file diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_sdxl.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_sdxl.json new file mode 100644 index 0000000000000000000000000000000000000000..dd3c24475b9a933839567e20990d0910944ba82e --- /dev/null +++ b/ai-toolkit/toolkit/keymaps/stable_diffusion_sdxl.json @@ -0,0 +1,4154 @@ +{ + "ldm_diffusers_keymap": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "te0_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": 
"te0_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te0_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te0_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te0_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te0_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te0_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te0_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te0_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te0_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te0_text_model.encoder.layers.0.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te0_text_model.encoder.layers.0.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te0_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te0_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te0_text_model.encoder.layers.0.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": 
"te0_text_model.encoder.layers.0.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te0_text_model.encoder.layers.0.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te0_text_model.encoder.layers.0.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te0_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te0_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te0_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te0_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te0_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te0_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te0_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te0_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te0_text_model.encoder.layers.1.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te0_text_model.encoder.layers.1.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te0_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": 
"te0_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te0_text_model.encoder.layers.1.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te0_text_model.encoder.layers.1.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te0_text_model.encoder.layers.1.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te0_text_model.encoder.layers.1.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te0_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te0_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te0_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te0_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te0_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te0_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te0_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te0_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te0_text_model.encoder.layers.10.self_attn.k_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te0_text_model.encoder.layers.10.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te0_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te0_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te0_text_model.encoder.layers.10.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te0_text_model.encoder.layers.10.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te0_text_model.encoder.layers.10.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te0_text_model.encoder.layers.10.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te0_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te0_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te0_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te0_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te0_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te0_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.bias": 
"te0_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te0_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te0_text_model.encoder.layers.11.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te0_text_model.encoder.layers.11.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te0_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te0_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te0_text_model.encoder.layers.11.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te0_text_model.encoder.layers.11.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te0_text_model.encoder.layers.11.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te0_text_model.encoder.layers.11.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te0_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te0_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te0_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te0_text_model.encoder.layers.2.layer_norm2.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te0_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te0_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te0_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te0_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te0_text_model.encoder.layers.2.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te0_text_model.encoder.layers.2.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te0_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te0_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te0_text_model.encoder.layers.2.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te0_text_model.encoder.layers.2.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te0_text_model.encoder.layers.2.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te0_text_model.encoder.layers.2.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te0_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": 
"te0_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te0_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te0_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te0_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te0_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te0_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te0_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te0_text_model.encoder.layers.3.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te0_text_model.encoder.layers.3.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te0_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te0_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te0_text_model.encoder.layers.3.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te0_text_model.encoder.layers.3.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te0_text_model.encoder.layers.3.self_attn.v_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te0_text_model.encoder.layers.3.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te0_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te0_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te0_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te0_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te0_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te0_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te0_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te0_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te0_text_model.encoder.layers.4.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te0_text_model.encoder.layers.4.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te0_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te0_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te0_text_model.encoder.layers.4.self_attn.q_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te0_text_model.encoder.layers.4.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te0_text_model.encoder.layers.4.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te0_text_model.encoder.layers.4.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te0_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te0_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te0_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te0_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te0_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te0_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te0_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te0_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te0_text_model.encoder.layers.5.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te0_text_model.encoder.layers.5.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te0_text_model.encoder.layers.5.self_attn.out_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te0_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te0_text_model.encoder.layers.5.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te0_text_model.encoder.layers.5.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te0_text_model.encoder.layers.5.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te0_text_model.encoder.layers.5.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te0_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te0_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te0_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te0_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te0_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te0_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te0_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te0_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te0_text_model.encoder.layers.6.self_attn.k_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te0_text_model.encoder.layers.6.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te0_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te0_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te0_text_model.encoder.layers.6.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te0_text_model.encoder.layers.6.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te0_text_model.encoder.layers.6.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te0_text_model.encoder.layers.6.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te0_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te0_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te0_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te0_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te0_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te0_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": 
"te0_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te0_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te0_text_model.encoder.layers.7.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te0_text_model.encoder.layers.7.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te0_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te0_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te0_text_model.encoder.layers.7.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te0_text_model.encoder.layers.7.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te0_text_model.encoder.layers.7.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te0_text_model.encoder.layers.7.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te0_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te0_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te0_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te0_text_model.encoder.layers.8.layer_norm2.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te0_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te0_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te0_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te0_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te0_text_model.encoder.layers.8.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te0_text_model.encoder.layers.8.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te0_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te0_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te0_text_model.encoder.layers.8.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te0_text_model.encoder.layers.8.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te0_text_model.encoder.layers.8.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te0_text_model.encoder.layers.8.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te0_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": 
"te0_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te0_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te0_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te0_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te0_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te0_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te0_text_model.encoder.layers.9.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te0_text_model.encoder.layers.9.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te0_text_model.encoder.layers.9.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te0_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te0_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te0_text_model.encoder.layers.9.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te0_text_model.encoder.layers.9.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te0_text_model.encoder.layers.9.self_attn.v_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te0_text_model.encoder.layers.9.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.bias": "te0_text_model.final_layer_norm.bias", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": "te0_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.ln_final.bias": "te1_text_model.final_layer_norm.bias", + "conditioner.embedders.1.model.ln_final.weight": "te1_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.1.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": 
"te1_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": 
"te1_text_model.encoder.layers.15.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": 
"te1_text_model.encoder.layers.20.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": 
"te1_text_model.encoder.layers.24.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": 
"te1_text_model.encoder.layers.27.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": 
"te1_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": 
"te1_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": 
"vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + 
"first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": 
"vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + 
"first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", 
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + 
"first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + 
"first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": 
"vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + 
"first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + 
"first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + 
"first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": 
"vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + 
"first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + 
"model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": 
"unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": 
"unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": 
"unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": 
"unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + 
"model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_k.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.5.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.0.proj.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_q.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.7.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_q.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_out.0.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.9.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + 
"model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": 
"unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_k.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm3.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_v.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm2.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_v.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm1.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm3.weight", + "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias", + 
"model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight", + "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias", + "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": 
"unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": 
"unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_k.weight": 
"unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.weight": 
"unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_k.weight": 
"unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.weight": 
"unet_mid_block.attentions.0.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.4.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.4.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.4.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.4.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.4.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.4.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_out.0.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.5.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.5.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.5.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.5.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.5.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.5.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_k.weight": 
"unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.weight": 
"unet_mid_block.attentions.0.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.6.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.6.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.6.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.6.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.6.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.6.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_out.0.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.7.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.7.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.7.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.7.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.7.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.7.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_k.weight": 
"unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.weight": 
"unet_mid_block.attentions.0.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.8.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.8.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.8.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.8.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.8.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.8.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.9.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.9.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.9.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.9.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.9.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.9.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias", 
+ "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": 
"unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.0.1.norm.bias": "unet_up_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.0.1.norm.weight": "unet_up_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.0.1.proj_in.bias": "unet_up_blocks.0.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.0.1.proj_in.weight": "unet_up_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.0.1.proj_out.bias": "unet_up_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.0.1.proj_out.weight": "unet_up_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": 
"unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.bias", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_v.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm1.bias", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_out.0.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_k.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.bias": 
"unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.0.proj.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm2.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.1.norm.bias": "unet_up_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.1.1.norm.weight": "unet_up_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.1.1.proj_in.bias": "unet_up_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.1.1.proj_in.weight": 
"unet_up_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.1.1.proj_out.bias": "unet_up_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.1.1.proj_out.weight": "unet_up_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": 
"unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_q.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.bias": 
"unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm3.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.bias": 
"unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_k.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.bias": 
"unet_up_blocks.0.attentions.1.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm2.bias", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_v.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + 
"model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.norm.bias": "unet_up_blocks.0.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.2.1.norm.weight": "unet_up_blocks.0.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.2.1.proj_in.bias": "unet_up_blocks.0.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.2.1.proj_in.weight": "unet_up_blocks.0.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.2.1.proj_out.bias": "unet_up_blocks.0.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.2.1.proj_out.weight": "unet_up_blocks.0.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.bias": 
"unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.bias": 
"unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_k.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.bias": 
"unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_v.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.2.2.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.2.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + 
"model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": 
"unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + 
"model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + 
"model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": 
"unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + 
"model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": 
"unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": 
"unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight", 
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias", + 
"te1_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias", + 
"te1_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias", + 
"te1_text_model.encoder.layers.26.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.29.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias", + 
"te1_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight", + 
"te1_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.24.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.26.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.27.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "0:1280, :" + ] + }, + 
"te1_text_model.encoder.layers.29.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "2560:, :" + ] + }, + 
"te1_text_model.encoder.layers.30.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "1280:2560, :" + ] + }, + 
"te1_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": 
{ + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "2560:, :" + ] + } + } +} \ No newline at end of file diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_sdxl_unmatched.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_sdxl_unmatched.json new file mode 100644 index 0000000000000000000000000000000000000000..d0b2554ae6e6fd8bd12d660cdb64437132c8b52d --- /dev/null +++ b/ai-toolkit/toolkit/keymaps/stable_diffusion_sdxl_unmatched.json @@ -0,0 +1,35 @@ +{ + "ldm": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_ids": { + "shape": [ + 1, + 77 + ], + "min": 0.0, + "max": 76.0 + }, + "conditioner.embedders.1.model.logit_scale": { + "shape": [], + "min": 4.60546875, + "max": 4.60546875 + }, + "conditioner.embedders.1.model.text_projection": { + "shape": [ + 1280, + 1280 + ], + "min": -0.15966796875, + "max": 0.230712890625 + } + }, + "diffusers": { + "te1_text_projection.weight": { + "shape": [ + 1280, + 1280 + ], + "min": -0.15966796875, + "max": 0.230712890625 + } + } +} \ No newline at end of file diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_ssd.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_ssd.json new file mode 100644 index 0000000000000000000000000000000000000000..9ad06407be7c6eedb4fcfa06805827bf1d2f6924 --- /dev/null +++ b/ai-toolkit/toolkit/keymaps/stable_diffusion_ssd.json @@ -0,0 +1,3419 @@ +{ + "ldm_diffusers_keymap": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "te0_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "te0_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te0_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te0_text_model.encoder.layers.0.layer_norm1.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te0_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te0_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te0_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te0_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te0_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te0_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te0_text_model.encoder.layers.0.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te0_text_model.encoder.layers.0.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te0_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te0_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te0_text_model.encoder.layers.0.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te0_text_model.encoder.layers.0.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te0_text_model.encoder.layers.0.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": 
"te0_text_model.encoder.layers.0.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te0_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te0_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te0_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te0_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te0_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te0_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te0_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te0_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te0_text_model.encoder.layers.1.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te0_text_model.encoder.layers.1.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te0_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te0_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te0_text_model.encoder.layers.1.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": 
"te0_text_model.encoder.layers.1.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te0_text_model.encoder.layers.1.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te0_text_model.encoder.layers.1.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te0_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te0_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te0_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te0_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te0_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te0_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te0_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te0_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te0_text_model.encoder.layers.10.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te0_text_model.encoder.layers.10.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te0_text_model.encoder.layers.10.self_attn.out_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te0_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te0_text_model.encoder.layers.10.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te0_text_model.encoder.layers.10.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te0_text_model.encoder.layers.10.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te0_text_model.encoder.layers.10.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te0_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te0_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te0_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te0_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te0_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te0_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te0_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te0_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": 
"te0_text_model.encoder.layers.11.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te0_text_model.encoder.layers.11.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te0_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te0_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te0_text_model.encoder.layers.11.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te0_text_model.encoder.layers.11.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te0_text_model.encoder.layers.11.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te0_text_model.encoder.layers.11.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te0_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te0_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te0_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te0_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te0_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te0_text_model.encoder.layers.2.mlp.fc1.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te0_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te0_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te0_text_model.encoder.layers.2.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te0_text_model.encoder.layers.2.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te0_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te0_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te0_text_model.encoder.layers.2.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te0_text_model.encoder.layers.2.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te0_text_model.encoder.layers.2.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te0_text_model.encoder.layers.2.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te0_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te0_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te0_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": 
"te0_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te0_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te0_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te0_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te0_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te0_text_model.encoder.layers.3.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te0_text_model.encoder.layers.3.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te0_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te0_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te0_text_model.encoder.layers.3.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te0_text_model.encoder.layers.3.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te0_text_model.encoder.layers.3.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te0_text_model.encoder.layers.3.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te0_text_model.encoder.layers.4.layer_norm1.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te0_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te0_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te0_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te0_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te0_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te0_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te0_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te0_text_model.encoder.layers.4.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te0_text_model.encoder.layers.4.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te0_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te0_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te0_text_model.encoder.layers.4.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te0_text_model.encoder.layers.4.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": 
"te0_text_model.encoder.layers.4.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te0_text_model.encoder.layers.4.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te0_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te0_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te0_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te0_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te0_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te0_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te0_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te0_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te0_text_model.encoder.layers.5.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te0_text_model.encoder.layers.5.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te0_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te0_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": 
"te0_text_model.encoder.layers.5.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te0_text_model.encoder.layers.5.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te0_text_model.encoder.layers.5.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te0_text_model.encoder.layers.5.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te0_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te0_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te0_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te0_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te0_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te0_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te0_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te0_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te0_text_model.encoder.layers.6.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te0_text_model.encoder.layers.6.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": 
"te0_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te0_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te0_text_model.encoder.layers.6.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te0_text_model.encoder.layers.6.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te0_text_model.encoder.layers.6.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te0_text_model.encoder.layers.6.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te0_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te0_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te0_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te0_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te0_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te0_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te0_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te0_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": 
"te0_text_model.encoder.layers.7.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te0_text_model.encoder.layers.7.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te0_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te0_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te0_text_model.encoder.layers.7.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te0_text_model.encoder.layers.7.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te0_text_model.encoder.layers.7.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te0_text_model.encoder.layers.7.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te0_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te0_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te0_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te0_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te0_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te0_text_model.encoder.layers.8.mlp.fc1.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te0_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te0_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te0_text_model.encoder.layers.8.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te0_text_model.encoder.layers.8.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te0_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te0_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te0_text_model.encoder.layers.8.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te0_text_model.encoder.layers.8.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te0_text_model.encoder.layers.8.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te0_text_model.encoder.layers.8.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te0_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te0_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te0_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": 
"te0_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te0_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te0_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te0_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te0_text_model.encoder.layers.9.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te0_text_model.encoder.layers.9.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te0_text_model.encoder.layers.9.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te0_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te0_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te0_text_model.encoder.layers.9.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te0_text_model.encoder.layers.9.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te0_text_model.encoder.layers.9.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te0_text_model.encoder.layers.9.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.bias": "te0_text_model.final_layer_norm.bias", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": 
"te0_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.ln_final.bias": "te1_text_model.final_layer_norm.bias", + "conditioner.embedders.1.model.ln_final.weight": "te1_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.1.model.text_projection.weight": "te1_text_projection.weight", + "conditioner.embedders.1.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": 
"te1_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": 
"te1_text_model.encoder.layers.13.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": 
"te1_text_model.encoder.layers.16.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.19.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": 
"te1_text_model.encoder.layers.22.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": 
"te1_text_model.encoder.layers.30.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": 
"te1_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": 
"te1_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + 
"first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": 
"vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": 
"vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + 
"first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + 
"first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", 
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + 
"first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": 
"vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + 
"first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": 
"vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": 
"vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + 
"first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + 
"first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + 
"model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": 
"unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": 
"unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + 
"model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": 
"unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + 
"model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", 
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": 
"unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias", + "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight", + "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias", + "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + 
"model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.0.1.norm.bias": "unet_up_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.0.1.norm.weight": "unet_up_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.0.1.proj_in.bias": "unet_up_blocks.0.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.0.1.proj_in.weight": "unet_up_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.0.1.proj_out.bias": "unet_up_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.0.1.proj_out.weight": "unet_up_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": 
"unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.bias": 
"unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_k.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.bias": 
"unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.1.norm.bias": "unet_up_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.1.1.norm.weight": "unet_up_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.1.1.proj_in.bias": "unet_up_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.1.1.proj_in.weight": "unet_up_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.1.1.proj_out.bias": 
"unet_up_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.1.1.proj_out.weight": "unet_up_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.weight": 
"unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + 
"model.diffusion_model.output_blocks.2.1.norm.bias": "unet_up_blocks.0.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.2.1.norm.weight": "unet_up_blocks.0.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.2.1.proj_in.bias": "unet_up_blocks.0.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.2.1.proj_in.weight": "unet_up_blocks.0.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.2.1.proj_out.bias": "unet_up_blocks.0.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.2.1.proj_out.weight": "unet_up_blocks.0.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": 
"unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.bias": 
"unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_k.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.bias": 
"unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_v.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.bias", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.2.2.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.2.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + 
"model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": 
"unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": 
"unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": 
"unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": 
"unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + 
"model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": 
"unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + 
"model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight", + 
"te1_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight", + 
"te1_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight", + 
"te1_text_model.encoder.layers.24.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.28.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight", + 
"te1_text_model.encoder.layers.30.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight", + 
"te1_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "2048:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "0:1024, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "1024:2048, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight", + "2048:, :" + ] + } + } +} \ No newline at end of file diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_ssd_unmatched.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_ssd_unmatched.json new file mode 100644 index 0000000000000000000000000000000000000000..6871c9eb6af0c9a4018e1b0ead6c9fd7c7ee387b --- /dev/null +++ b/ai-toolkit/toolkit/keymaps/stable_diffusion_ssd_unmatched.json @@ -0,0 +1,21 @@ +{ + "ldm": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_ids": { + "shape": [ + 1, + 77 + ], + "min": 0.0, + "max": 76.0 + }, + "conditioner.embedders.1.model.text_model.embeddings.position_ids": { + "shape": [ + 1, + 77 + ], + "min": 0.0, + "max": 76.0 + } + }, + "diffusers": {} +} \ No newline at end of file diff --git a/ai-toolkit/toolkit/keymaps/stable_diffusion_vega.json b/ai-toolkit/toolkit/keymaps/stable_diffusion_vega.json new file mode 100644 index 0000000000000000000000000000000000000000..4117c201963bea780c16acd720055699b92acf43 --- /dev/null +++ 
b/ai-toolkit/toolkit/keymaps/stable_diffusion_vega.json @@ -0,0 +1,3039 @@ +{ + "ldm_diffusers_keymap": { + "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "te0_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "te0_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te0_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te0_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te0_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te0_text_model.encoder.layers.0.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te0_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te0_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te0_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te0_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te0_text_model.encoder.layers.0.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te0_text_model.encoder.layers.0.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te0_text_model.encoder.layers.0.self_attn.out_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te0_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te0_text_model.encoder.layers.0.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te0_text_model.encoder.layers.0.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te0_text_model.encoder.layers.0.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te0_text_model.encoder.layers.0.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te0_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te0_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te0_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te0_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te0_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te0_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te0_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te0_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te0_text_model.encoder.layers.1.self_attn.k_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te0_text_model.encoder.layers.1.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te0_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te0_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te0_text_model.encoder.layers.1.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te0_text_model.encoder.layers.1.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te0_text_model.encoder.layers.1.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te0_text_model.encoder.layers.1.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te0_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te0_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te0_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te0_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te0_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te0_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": 
"te0_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te0_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te0_text_model.encoder.layers.10.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te0_text_model.encoder.layers.10.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te0_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te0_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te0_text_model.encoder.layers.10.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te0_text_model.encoder.layers.10.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te0_text_model.encoder.layers.10.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te0_text_model.encoder.layers.10.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te0_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te0_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te0_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te0_text_model.encoder.layers.11.layer_norm2.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te0_text_model.encoder.layers.11.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te0_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te0_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te0_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te0_text_model.encoder.layers.11.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te0_text_model.encoder.layers.11.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te0_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te0_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te0_text_model.encoder.layers.11.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te0_text_model.encoder.layers.11.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te0_text_model.encoder.layers.11.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te0_text_model.encoder.layers.11.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te0_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": 
"te0_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te0_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te0_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te0_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te0_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te0_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te0_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te0_text_model.encoder.layers.2.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te0_text_model.encoder.layers.2.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te0_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te0_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te0_text_model.encoder.layers.2.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te0_text_model.encoder.layers.2.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te0_text_model.encoder.layers.2.self_attn.v_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te0_text_model.encoder.layers.2.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te0_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te0_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te0_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te0_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te0_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te0_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te0_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te0_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te0_text_model.encoder.layers.3.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te0_text_model.encoder.layers.3.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te0_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te0_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te0_text_model.encoder.layers.3.self_attn.q_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te0_text_model.encoder.layers.3.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te0_text_model.encoder.layers.3.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te0_text_model.encoder.layers.3.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te0_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te0_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te0_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te0_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te0_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te0_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te0_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te0_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te0_text_model.encoder.layers.4.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te0_text_model.encoder.layers.4.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te0_text_model.encoder.layers.4.self_attn.out_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te0_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te0_text_model.encoder.layers.4.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te0_text_model.encoder.layers.4.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te0_text_model.encoder.layers.4.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te0_text_model.encoder.layers.4.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te0_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te0_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te0_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te0_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te0_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te0_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te0_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te0_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te0_text_model.encoder.layers.5.self_attn.k_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te0_text_model.encoder.layers.5.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te0_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te0_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te0_text_model.encoder.layers.5.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te0_text_model.encoder.layers.5.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te0_text_model.encoder.layers.5.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te0_text_model.encoder.layers.5.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te0_text_model.encoder.layers.6.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te0_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te0_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te0_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te0_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te0_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": 
"te0_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te0_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te0_text_model.encoder.layers.6.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te0_text_model.encoder.layers.6.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te0_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te0_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te0_text_model.encoder.layers.6.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te0_text_model.encoder.layers.6.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te0_text_model.encoder.layers.6.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te0_text_model.encoder.layers.6.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te0_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te0_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te0_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te0_text_model.encoder.layers.7.layer_norm2.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te0_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te0_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te0_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te0_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te0_text_model.encoder.layers.7.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te0_text_model.encoder.layers.7.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te0_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te0_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te0_text_model.encoder.layers.7.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te0_text_model.encoder.layers.7.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te0_text_model.encoder.layers.7.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te0_text_model.encoder.layers.7.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te0_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": 
"te0_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te0_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te0_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te0_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te0_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te0_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te0_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te0_text_model.encoder.layers.8.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te0_text_model.encoder.layers.8.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te0_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te0_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te0_text_model.encoder.layers.8.self_attn.q_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te0_text_model.encoder.layers.8.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te0_text_model.encoder.layers.8.self_attn.v_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te0_text_model.encoder.layers.8.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te0_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te0_text_model.encoder.layers.9.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te0_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te0_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te0_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te0_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te0_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te0_text_model.encoder.layers.9.mlp.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te0_text_model.encoder.layers.9.self_attn.k_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te0_text_model.encoder.layers.9.self_attn.k_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te0_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te0_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te0_text_model.encoder.layers.9.self_attn.q_proj.bias", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te0_text_model.encoder.layers.9.self_attn.q_proj.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te0_text_model.encoder.layers.9.self_attn.v_proj.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te0_text_model.encoder.layers.9.self_attn.v_proj.weight", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.bias": "te0_text_model.final_layer_norm.bias", + "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": "te0_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.ln_final.bias": "te1_text_model.final_layer_norm.bias", + "conditioner.embedders.1.model.ln_final.weight": "te1_text_model.final_layer_norm.weight", + "conditioner.embedders.1.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight", + "conditioner.embedders.1.model.text_projection.weight": "te1_text_projection.weight", + "conditioner.embedders.1.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": 
"te1_text_model.encoder.layers.10.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": 
"te1_text_model.encoder.layers.14.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": 
"te1_text_model.encoder.layers.17.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": 
"te1_text_model.encoder.layers.2.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": 
"te1_text_model.encoder.layers.23.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": 
"te1_text_model.encoder.layers.3.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": 
"te1_text_model.encoder.layers.4.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": 
"te1_text_model.encoder.layers.7.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight", + "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias", + 
"first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias", + 
"first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": 
"vae_decoder.up_blocks.3.resnets.2.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight", + 
"first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias", + 
"first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight", + 
"first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias", + 
"first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight", + "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": 
"vae_encoder.down_blocks.0.resnets.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": 
"vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias", + "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": 
"vae_encoder.down_blocks.2.resnets.0.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight", + 
"first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": 
"vae_encoder.mid_block.attentions.0.group_norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": 
"vae_encoder.mid_block.resnets.1.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight", + "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight", + "first_stage_model.quant_conv.bias": "vae_quant_conv.bias", + "first_stage_model.quant_conv.weight": "vae_quant_conv.weight", + "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": 
"unet_down_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias", + 
"model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": 
"unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias", + 
"model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": 
"unet_down_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": 
"unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": 
"unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias", 
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": 
"unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": 
"unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight", + "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias", + 
"model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight", + "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias", + "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight", + "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "unet_conv_out.bias", + "model.diffusion_model.out.2.weight": "unet_conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight", + 
"model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.0.1.norm.bias": "unet_up_blocks.0.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.0.1.norm.weight": "unet_up_blocks.0.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.0.1.proj_in.bias": "unet_up_blocks.0.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.0.1.proj_in.weight": "unet_up_blocks.0.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.0.1.proj_out.bias": "unet_up_blocks.0.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.0.1.proj_out.weight": "unet_up_blocks.0.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": 
"unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight", + 
"model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.1.norm.bias": "unet_up_blocks.0.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.1.1.norm.weight": "unet_up_blocks.0.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.1.1.proj_in.bias": "unet_up_blocks.0.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.1.1.proj_in.weight": "unet_up_blocks.0.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.1.1.proj_out.bias": "unet_up_blocks.0.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.1.1.proj_out.weight": "unet_up_blocks.0.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": 
"unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.bias", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.weight", + 
"model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.norm.bias": "unet_up_blocks.0.attentions.2.norm.bias", + "model.diffusion_model.output_blocks.2.1.norm.weight": "unet_up_blocks.0.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.2.1.proj_in.bias": "unet_up_blocks.0.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.2.1.proj_in.weight": "unet_up_blocks.0.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.2.1.proj_out.bias": "unet_up_blocks.0.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.2.1.proj_out.weight": "unet_up_blocks.0.attentions.2.proj_out.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": 
"unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": 
"unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": 
"unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.2.2.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.2.2.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": 
"unet_up_blocks.1.attentions.0.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": 
"unet_up_blocks.1.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": 
"unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", + 
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias", + 
"model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias", + 
"model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias", + 
"model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight", + "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias", + "model.diffusion_model.time_embed.0.weight": 
"unet_time_embedding.linear_1.weight", + "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias", + "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight" + }, + "ldm_diffusers_shape_map": { + "first_stage_model.decoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.decoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.k.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.proj_out.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.q.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ], + "first_stage_model.encoder.mid.attn_1.v.weight": [ + [ + 512, + 512, + 1, + 1 + ], + [ + 512, + 512 + ] + ] + }, + "ldm_diffusers_operator_map": { + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.1.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight", + 
"te1_text_model.encoder.layers.12.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.16.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight", + 
"te1_text_model.encoder.layers.19.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.22.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight", + 
"te1_text_model.encoder.layers.25.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.29.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight", + 
"te1_text_model.encoder.layers.31.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias", + 
"te1_text_model.encoder.layers.7.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias", + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias" + ] + }, + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": { + "cat": [ + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.k_proj.weight", + "te1_text_model.encoder.layers.9.self_attn.v_proj.weight" + ] + } + }, + "diffusers_ldm_operator_map": { + "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": { + "slice": [ + 
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "0:1280, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "1280:2560, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": { + "slice": [ + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias", + "2560:, :" + ] + }, + "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": { + "slice": [ + 
class LoRAModule(torch.nn.Module):
    """
    Low-rank adapter for a single ``Linear`` or ``Conv2d`` layer.

    Replaces the forward method of the original module instead of replacing
    the module itself: ``apply_to`` rebinds ``org_module.forward`` to this
    module's ``forward``, which adds ``lora_up(lora_down(x))`` scaled by
    ``multiplier * alpha / lora_dim`` to the original output.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """
        Build the down/up projection pair for ``org_module``.

        Args:
            lora_name: name stored on the module as ``self.lora_name``.
            org_module: the layer to adapt. A module whose class is named
                ``Conv2d`` is read via ``in_channels``/``out_channels``;
                anything else via ``in_features``/``out_features``.
            multiplier: factor applied to the LoRA branch output.
            lora_dim: rank of the low-rank decomposition.
            alpha: scaling numerator (number or single-element tensor).
                If ``None`` or ``0``, alpha is the rank (no scaling).
            dropout: probability for ordinary dropout on the down projection
                output, or ``None`` to disable.
            rank_dropout: probability of dropping each rank channel per
                sample, or ``None`` to disable.
            module_dropout: probability of skipping the LoRA branch for a
                whole forward call, or ``None`` to disable.
        """
        super().__init__()
        self.lora_name = lora_name

        is_conv = org_module.__class__.__name__ == "Conv2d"
        if is_conv:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        # if limit_rank:
        #   self.lora_dim = min(lora_dim, in_dim, out_dim)
        #   if self.lora_dim != lora_dim:
        #     print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
        # else:
        self.lora_dim = lora_dim

        if is_conv:
            # The down projection reuses the original kernel geometry; the up projection is 1x1.
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            # Cast to float32 first (bf16 cannot be converted directly), then pull out a
            # Python float. Unlike .numpy(), .item() also works for tensors that live on
            # an accelerator, and it keeps self.scale a plain float.
            alpha = alpha.detach().float().item()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        """Hook this module into ``org_module`` by swapping its ``forward``.

        The original bound ``forward`` is kept as ``self.org_forward`` and the
        reference to ``org_module`` is dropped afterwards, so this can only be
        called once per instance.
        """
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        """Return the original module output plus the scaled LoRA delta.

        The three dropout variants are only active when ``self.training`` is
        true and the corresponding probability is not ``None``.
        """
        org_forwarded = self.org_forward(x)

        # module dropout: skip the whole LoRA branch for this call
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x)

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is changed
            # It could also be computed from the mask, but rank_dropout is used
            # in the expectation of an augmentation-like effect.
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded + lx * self.multiplier * scale
sd["lora_down.weight"].to(torch.float).to(device) + + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"].to(torch.float) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # print("default_forward", self.lora_name, x.size()) + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + def forward(self, x): + 
if not self.enabled: + return self.org_forward(x) + + if self.network is None or self.network.sub_prompt_index is None: + return self.default_forward(x) + if not self.regional and not self.use_sub_prompt: + return self.default_forward(x) + + if self.regional: + return self.regional_forward(x) + else: + return self.sub_prompt_forward(x) + + def get_mask_for_x(self, x): + # calculate size from shape of x + if len(x.size()) == 4: + h, w = x.size()[2:4] + area = h * w + else: + area = x.size()[1] + + mask = self.network.mask_dic[area] + if mask is None: + raise ValueError(f"mask is None for resolution {area}") + if len(x.size()) != 4: + mask = torch.reshape(mask, (1, -1, 1)) + return mask + + def regional_forward(self, x): + if "attn2_to_out" in self.lora_name: + return self.to_out_forward(x) + + if self.network.mask_dic is None: # sub_prompt_index >= 3 + return self.default_forward(x) + + # apply mask for LoRA result + lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + mask = self.get_mask_for_x(lx) + # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + lx = lx * mask + + x = self.org_forward(x) + x = x + lx + + if "attn2_to_q" in self.lora_name and self.network.is_last_network: + x = self.postp_to_q(x) + + return x + + def postp_to_q(self, x): + # repeat x to num_sub_prompts + has_real_uncond = x.size()[0] // self.network.batch_size == 3 + qc = self.network.batch_size # uncond + qc += self.network.batch_size * self.network.num_sub_prompts # cond + if has_real_uncond: + qc += self.network.batch_size # real_uncond + + query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype) + query[: self.network.batch_size] = x[: self.network.batch_size] + + for i in range(self.network.batch_size): + qi = self.network.batch_size + i * self.network.num_sub_prompts + query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i] + + if has_real_uncond: + query[-self.network.batch_size 
:] = x[-self.network.batch_size :] + + # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + return query + + def sub_prompt_forward(self, x): + if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA + return self.org_forward(x) + + emb_idx = self.network.sub_prompt_index + if not self.text_encoder: + emb_idx += self.network.batch_size + + # apply sub prompt of X + lx = x[emb_idx :: self.network.num_sub_prompts] + lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale + + # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + + x = self.org_forward(x) + x[emb_idx :: self.network.num_sub_prompts] += lx + + return x + + def to_out_forward(self, x): + # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + + if self.network.is_last_network: + masks = [None] * self.network.num_sub_prompts + self.network.shared[self.lora_name] = (None, masks) + else: + lx, masks = self.network.shared[self.lora_name] + + # call own LoRA + x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts] + lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale + + if self.network.is_last_network: + lx = torch.zeros( + (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype + ) + self.network.shared[self.lora_name] = (lx, masks) + + # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 + masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) + + # if not last network, return x and masks + x = self.org_forward(x) + if not self.network.is_last_network: + return x + + lx, masks = self.network.shared.pop(self.lora_name) + + # if last network, combine separated x with mask weighted sum + has_real_uncond = x.size()[0] // 
self.network.batch_size == self.network.num_sub_prompts + 2 + + out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype) + out[: self.network.batch_size] = x[: self.network.batch_size] # uncond + if has_real_uncond: + out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond + + # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # for i in range(len(masks)): + # if masks[i] is None: + # masks[i] = torch.zeros_like(masks[-1]) + + mask = torch.cat(masks) + mask_sum = torch.sum(mask, dim=0) + 1e-4 + for i in range(self.network.batch_size): + # 1枚の画像ごとに処理する + lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts] + lx1 = lx1 * mask + lx1 = torch.sum(lx1, dim=0) + + xi = self.network.batch_size + i * self.network.num_sub_prompts + x1 = x[xi : xi + self.network.num_sub_prompts] + x1 = x1 * mask + x1 = torch.sum(x1, dim=0) + x1 = x1 / mask_sum + + x1 = x1 + lx1 + out[self.network.batch_size + i] = x1 + + # print("to_out_forward", x.size(), out.size(), has_real_uncond) + return out + + +def parse_block_lr_kwargs(nw_kwargs): + down_lr_weight = nw_kwargs.get("down_lr_weight", None) + mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) + up_lr_weight = nw_kwargs.get("up_lr_weight", None) + + # 以上のいずれにも設定がない場合は無効としてNoneを返す + if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: + return None, None, None + + # extract learning rate weight for each block + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight, mid_lr_weight, 
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: AutoencoderKL,
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a ``LoRANetwork`` for training from command-line style arguments.

    Args:
        multiplier: LoRA strength passed to the network.
        network_dim: base rank; ``None`` means 4.
        network_alpha: base alpha; ``None`` means 1.0.
        vae: accepted for interface compatibility; not used here.
        text_encoder: one text encoder or a list of them.
        unet: the U-Net to attach LoRA modules to.
        neuron_dropout: forwarded to the network as ``dropout``.
        **kwargs: optional settings read here: ``conv_dim``, ``conv_alpha``,
            ``block_dims``, ``block_alphas``, ``conv_block_dims``,
            ``conv_block_alphas``, ``rank_dropout``, ``module_dropout`` and the
            block learning-rate keys consumed by ``parse_block_lr_kwargs``.

    Returns:
        The constructed ``LoRANetwork``.
    """
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
        network_alpha = 1.0

    # extract dim/alpha for conv2d, and block dim
    conv_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_dim is not None:
        conv_dim = int(conv_dim)
        if conv_alpha is None:
            conv_alpha = 1.0
    if conv_alpha is not None:
        # Cast even when conv_dim is absent: conv_block_dims may still use
        # conv_alpha as the per-block default, and a raw string would be
        # propagated as an alpha otherwise.
        conv_alpha = float(conv_alpha)

    # block dim/alpha/lr
    block_dims = kwargs.get("block_dims", None)
    down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)

    # enable per-block dim (rank) if any of the above is specified
    if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
        block_alphas = kwargs.get("block_alphas", None)
        conv_block_dims = kwargs.get("conv_block_dims", None)
        conv_block_alphas = kwargs.get("conv_block_alphas", None)

        block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
            block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
        )

        # remove block dim/alpha without learning rate
        block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
            block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
        )

    else:
        block_alphas = None
        conv_block_dims = None
        conv_block_alphas = None

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
        rank_dropout = float(rank_dropout)
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    network = LoRANetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        conv_lora_dim=conv_dim,
        conv_alpha=conv_alpha,
        block_dims=block_dims,
        block_alphas=block_alphas,
        conv_block_dims=conv_block_dims,
        conv_block_alphas=conv_block_alphas,
        varbose=True,  # spelling matches the LoRANetwork parameter name
    )

    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)

    return network
all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + block_dims = [network_dim] * num_total_blocks + + if block_alphas is not None: + block_alphas = parse_floats(block_alphas) + assert ( + len(block_alphas) == num_total_blocks + ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" + else: + print( + f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" + ) + block_alphas = [network_alpha] * num_total_blocks + + # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う + if conv_block_dims is not None: + conv_block_dims = parse_ints(conv_block_dims) + assert ( + len(conv_block_dims) == num_total_blocks + ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください" + + if conv_block_alphas is not None: + conv_block_alphas = parse_floats(conv_block_alphas) + assert ( + len(conv_block_alphas) == num_total_blocks + ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください" + else: + if conv_alpha is None: + conv_alpha = 1.0 + print( + f"conv_block_alphas is not specified. 
def get_block_lr_weight(
    down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold, max_len: Optional[int] = None
) -> Tuple[List[float], List[float], List[float]]:
    """Normalise per-block learning-rate multipliers.

    May be called from outside this module.

    Args:
        down_lr_weight: list of floats, a preset string
            (``cosine``/``sine``/``linear``/``reverse_linear``/``zeros``,
            optionally ``name+base``), or ``None``.
        mid_lr_weight: a single float or ``None``.
        up_lr_weight: same forms as ``down_lr_weight``.
        zero_threshold: weights not greater than this value are replaced by 0.
        max_len: number of up/down blocks; ``None`` means
            ``LoRANetwork.NUM_OF_BLOCKS``.

    Returns:
        ``(down_lr_weight, mid_lr_weight, up_lr_weight)``. Lists are truncated
        or padded with 1.0 to ``max_len``. A side that was not given, or whose
        preset name is unknown, is ``None``.
    """
    # when nothing is specified, do nothing and keep the previous behaviour
    if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
        return None, None, None

    if max_len is None:
        max_len = LoRANetwork.NUM_OF_BLOCKS  # number of up/down blocks of the full model

    def get_list(name_with_suffix) -> Optional[List[float]]:
        import math

        tokens = name_with_suffix.split("+")
        name = tokens[0]
        base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
        span = max(max_len - 1, 1)  # guards the degenerate single-block case

        if name == "cosine":
            return [math.sin(math.pi * (i / span) / 2) + base_lr for i in reversed(range(max_len))]
        elif name == "sine":
            return [math.sin(math.pi * (i / span) / 2) + base_lr for i in range(max_len)]
        elif name == "linear":
            return [i / span + base_lr for i in range(max_len)]
        elif name == "reverse_linear":
            return [i / span + base_lr for i in reversed(range(max_len))]
        elif name == "zeros":
            return [0.0 + base_lr] * max_len
        else:
            # the template has two placeholders, so the name must be supplied twice
            print(
                "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
                % (name, name)
            )
            return None

    if isinstance(down_lr_weight, str):
        down_lr_weight = get_list(down_lr_weight)
    if isinstance(up_lr_weight, str):
        up_lr_weight = get_list(up_lr_weight)

    up_too_long = up_lr_weight is not None and len(up_lr_weight) > max_len
    down_too_long = down_lr_weight is not None and len(down_lr_weight) > max_len
    if up_too_long or down_too_long:
        print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
        print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
        # truncate each side on its own: the other one may legitimately be None
        if up_too_long:
            up_lr_weight = up_lr_weight[:max_len]
        if down_too_long:
            down_lr_weight = down_lr_weight[:max_len]

    up_too_short = up_lr_weight is not None and len(up_lr_weight) < max_len
    down_too_short = down_lr_weight is not None and len(down_lr_weight) < max_len
    if up_too_short or down_too_short:
        print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
        print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)

        if down_too_short:
            down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
        if up_too_short:
            up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))

    if (up_lr_weight is not None) or (mid_lr_weight is not None) or (down_lr_weight is not None):
        print("apply block learning rate / 階層別学習率を適用します。")
        if down_lr_weight is not None:
            down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
            print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
        else:
            print("down_lr_weight: all 1.0, すべて1.0")

        if mid_lr_weight is not None:
            mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
            print("mid_lr_weight:", mid_lr_weight)
        else:
            print("mid_lr_weight: 1.0")

        if up_lr_weight is not None:
            up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
            print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
        else:
            print("up_lr_weight: all 1.0, すべて1.0")

    return down_lr_weight, mid_lr_weight, up_lr_weight
class LoRANetwork(torch.nn.Module):
    """Collection of LoRA modules attached to text encoder(s) and a U-Net."""

    NUM_OF_BLOCKS = 12  # number of up/down blocks of the full model

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        conv_lora_dim: Optional[int] = None,
        conv_alpha: Optional[float] = None,
        block_dims: Optional[List[int]] = None,
        block_alphas: Optional[List[float]] = None,
        conv_block_dims: Optional[List[int]] = None,
        conv_block_alphas: Optional[List[float]] = None,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        module_class: Type[object] = LoRAModule,
        varbose: Optional[bool] = False,
    ) -> None:
        """
        LoRA network. There are many arguments; the supported combinations are:
          1. lora_dim and alpha
          2. lora_dim, alpha, conv_lora_dim and conv_alpha
          3. block_dims and block_alphas: not applied to Conv2d 3x3
          4. block_dims, block_alphas, conv_block_dims and conv_block_alphas: also applied to Conv2d 3x3
          5. modules_dim and modules_alpha (for inference)

        ``varbose`` (sic) enables printing the names of skipped modules.
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        if modules_dim is not None:
            print(f"create LoRA network from weights")
        elif block_dims is not None:
            print(f"create LoRA network from block_dims")
            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
            print(f"block_dims: {block_dims}")
            print(f"block_alphas: {block_alphas}")
            if conv_block_dims is not None:
                print(f"conv_block_dims: {conv_block_dims}")
                print(f"conv_block_alphas: {conv_block_alphas}")
        else:
            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
            if self.conv_lora_dim is not None:
                print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

        # create module instances
        def create_modules(
            is_unet: bool,
            text_encoder_idx: Optional[int],  # None, 1, 2
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
        ) -> Tuple[List[LoRAModule], List[str]]:
            """Return (created LoRA modules, names of skipped modules) for ``root_module``."""
            prefix = (
                self.LORA_PREFIX_UNET
                if is_unet
                else (
                    self.LORA_PREFIX_TEXT_ENCODER
                    if text_encoder_idx is None
                    else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
                )
            )
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            dim = None
                            alpha = None

                            if modules_dim is not None:
                                # per-module dims given (loaded from weights)
                                if lora_name in modules_dim:
                                    dim = modules_dim[lora_name]
                                    alpha = modules_alpha[lora_name]
                            elif is_unet and block_dims is not None:
                                # U-Net with block_dims given
                                block_idx = get_block_index(lora_name)
                                if is_linear or is_conv2d_1x1:
                                    dim = block_dims[block_idx]
                                    alpha = block_alphas[block_idx]
                                elif conv_block_dims is not None:
                                    dim = conv_block_dims[block_idx]
                                    alpha = conv_block_alphas[block_idx]
                            else:
                                # normal case: target everything
                                if is_linear or is_conv2d_1x1:
                                    dim = self.lora_dim
                                    alpha = self.alpha
                                elif self.conv_lora_dim is not None:
                                    dim = self.conv_lora_dim
                                    alpha = self.conv_alpha

                            if dim is None or dim == 0:
                                # record what was skipped
                                if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha,
                                dropout=dropout,
                                rank_dropout=rank_dropout,
                                module_dropout=module_dropout,
                            )
                            loras.append(lora)
            return loras, skipped

        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

        # create LoRA for text encoder
        # creating every module each time is wasteful; to be reconsidered
        self.text_encoder_loras = []
        skipped_te = []
        for i, te in enumerate(text_encoders):
            if len(text_encoders) > 1:
                index = i + 1
                print(f"create LoRA for Text Encoder {index}:")
            else:
                index = None
                print(f"create LoRA for Text Encoder:")

            text_encoder_loras, skipped = create_modules(False, index, te, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            self.text_encoder_loras.extend(text_encoder_loras)
            skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights.
        # Work on a copy: `+=` on the class attribute itself would extend the shared
        # list in place and leak the conv targets into every later network.
        target_modules = list(LoRANetwork.UNET_TARGET_REPLACE_MODULE)
        if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        skipped = skipped_te + skipped_un
        if varbose and len(skipped) > 0:
            print(
                f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
            )
            for name in skipped:
                print(f"\t{name}")

        self.up_lr_weight: List[float] = None
        self.down_lr_weight: List[float] = None
        self.mid_lr_weight: float = None
        self.block_lr = False

        # assertion
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def set_multiplier(self, multiplier):
        """Set the strength on the network and on every LoRA module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Load weights from ``file`` non-strictly and return ``load_state_dict``'s result."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            # NOTE(review): torch.load unpickles the file; only use with trusted checkpoints.
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        """Hook the selected LoRA modules into their targets and register them as sub-modules.

        LoRA lists for a disabled side are discarded.
        """
        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    def is_mergeable(self):
        """Report whether this network supports merging (always ``True``)."""
        return True

    # TODO refactor to common function with apply_to
    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
        """Merge the weights in ``weights_sd`` into the wrapped modules.

        Which side (text encoder / U-Net) is merged is decided from the key
        prefixes present in ``weights_sd``.
        """
        apply_text_encoder = apply_unet = False
        for key in weights_sd.keys():
            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                apply_text_encoder = True
            elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
                apply_unet = True

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            # include the "." separator so a module whose name is a prefix of
            # another module's name cannot pick up the other module's tensors
            key_prefix = lora.lora_name + "."
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(key_prefix):
                    sd_for_lora[key[len(key_prefix) :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

        print(f"weights are merged")

    # per-block learning-rate multipliers; note the argument order (up, mid, down)
    def set_block_lr_weight(
        self,
        up_lr_weight: List[float] = None,
        mid_lr_weight: float = None,
        down_lr_weight: List[float] = None,
    ):
        """Enable block-wise learning rates and store the three multipliers."""
        self.block_lr = True
        self.down_lr_weight = down_lr_weight
        self.mid_lr_weight = mid_lr_weight
        self.up_lr_weight = up_lr_weight

    def get_lr_weight(self, lora: LoRAModule) -> float:
        """Return the learning-rate multiplier for ``lora``'s block (1.0 when none applies)."""
        lr_weight = 1.0
        block_idx = get_block_index(lora.lora_name)
        if block_idx < 0:
            return lr_weight

        if block_idx < LoRANetwork.NUM_OF_BLOCKS:
            if self.down_lr_weight is not None:
                lr_weight = self.down_lr_weight[block_idx]
        elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
            if self.mid_lr_weight is not None:
                lr_weight = self.mid_lr_weight
        elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
            if self.up_lr_weight is not None:
                lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]

        return lr_weight

    # it might be nice to allow separate learning rates for the two text encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """Build optimizer parameter groups.

        Returns a list of dicts with ``"params"`` and, when a rate is known,
        ``"lr"``. With block learning rates enabled, the U-Net gets one group per
        block, and groups whose resulting ``lr`` is 0 are omitted.
        """
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            if self.block_lr:
                # group LoRA modules by block so each block gets its own learning rate
                block_idx_to_lora = {}
                for lora in self.unet_loras:
                    idx = get_block_index(lora.lora_name)
                    if idx not in block_idx_to_lora:
                        block_idx_to_lora[idx] = []
                    block_idx_to_lora[idx].append(lora)

                # one parameter group per block
                for idx, block_loras in block_idx_to_lora.items():
                    param_data = {"params": enumerate_params(block_loras)}

                    if unet_lr is not None:
                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
                    elif default_lr is not None:
                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
                    if ("lr" in param_data) and (param_data["lr"] == 0):
                        continue
                    all_params.append(param_data)

            else:
                param_data = {"params": enumerate_params(self.unet_loras)}
                if unet_lr is not None:
                    param_data["lr"] = unet_lr
                all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        """No-op: gradient checkpointing is not supported."""
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        """Enable gradients on all parameters of this network."""
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        """Switch the network to training mode."""
        self.train()

    def get_trainable_params(self):
        """Return an iterator over this network's parameters."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save the state dict to ``file`` (safetensors when the extension says so, else ``torch.save``).

        When ``dtype`` is given, tensors are copied to CPU and cast first. An
        empty ``metadata`` mapping is treated like ``None``.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            # model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            # metadata["sshs_model_hash"] = model_hash
            # metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    # mask is a tensor with values from 0 to 1
    def set_region(self, sub_prompt_index, is_last_network, mask):
        """Store regional settings and register this network on every LoRA module.

        An all-zero ``mask`` is replaced by an all-ones mask.
        """
        if mask.max() == 0:
            mask = torch.ones_like(mask)

        self.mask = mask
        self.sub_prompt_index = sub_prompt_index
        self.is_last_network = is_last_network

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.set_network(self)

    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
        """Store per-generation state and build ``mask_dic``, keyed by ``h * w`` of each resized mask."""
        self.batch_size = batch_size
        self.num_sub_prompts = num_sub_prompts
        self.current_size = (height, width)
        self.shared = shared

        # create masks
        mask = self.mask
        mask_dic = {}
        mask = mask.unsqueeze(0).unsqueeze(1)  # b(1),c(1),h,w
        ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
        dtype = ref_weight.dtype
        device = ref_weight.device

        def resize_add(mh, mw):
            m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")  # doesn't work in bf16
            m = m.to(device, dtype=dtype)
            mask_dic[mh * mw] = m

        h = height // 8
        w = width // 8
        for _ in range(4):
            resize_add(h, w)
            if h % 2 == 1 or w % 2 == 1:  # add extra shape if h/w is not divisible by 2
                resize_add(h + h % 2, w + w % 2)
            h = (h + 1) // 2
            w = (w + 1) // 2

        self.mask_dic = mask_dic

    def backup_weights(self):
        """Keep a copy of each wrapped module's weight (only once per module)."""
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
                org_module._lora_restored = True

    def restore_weights(self):
        """Put the backed-up weight back into every module not already restored."""
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        """Add each LoRA delta to its wrapped module's weight and disable the LoRA forward path."""
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            sd = org_module.state_dict()

            org_weight = sd["weight"]
            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            sd["weight"] = org_weight + lora_weight
            assert sd["weight"].shape == org_weight.shape
            org_module.load_state_dict(sd)

            org_module._lora_restored = False
            lora.enabled = False

    def apply_max_norm_regularization(self, max_norm_value, device):
        """Rescale up/down weights whose combined norm exceeds ``max_norm_value``.

        Returns:
            ``(keys_scaled, mean_norm, max_norm)``. When the state dict holds no
            ``lora_down`` weights, both norm values are 0.0.
        """
        downkeys = []
        upkeys = []
        alphakeys = []
        norms = []
        keys_scaled = 0

        state_dict = self.state_dict()
        for key in state_dict.keys():
            if "lora_down" in key and "weight" in key:
                downkeys.append(key)
                upkeys.append(key.replace("lora_down", "lora_up"))
                alphakeys.append(key.replace("lora_down.weight", "alpha"))

        for i in range(len(downkeys)):
            down = state_dict[downkeys[i]].to(device)
            up = state_dict[upkeys[i]].to(device)
            alpha = state_dict[alphakeys[i]].to(device)
            dim = down.shape[0]
            scale = alpha / dim

            if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
            elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
                updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
            else:
                updown = up @ down

            updown *= scale

            norm = updown.norm().clamp(min=max_norm_value / 2)
            desired = torch.clamp(norm, max=max_norm_value)
            ratio = desired.cpu() / norm.cpu()
            # split the correction evenly between the up and down factors
            sqrt_ratio = ratio**0.5
            if ratio != 1:
                keys_scaled += 1
                state_dict[upkeys[i]] *= sqrt_ratio
                state_dict[downkeys[i]] *= sqrt_ratio
            scalednorm = updown.norm() * ratio
            norms.append(scalednorm.item())

        if not norms:
            # nothing to regularise; avoid dividing by zero / max() of an empty list
            return keys_scaled, 0.0, 0.0

        return keys_scaled, sum(norms) / len(norms), max(norms)
+# v2: support safetensors + +import math +import os +import re + +import torch +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from safetensors.torch import load_file, save_file +from collections import OrderedDict + +# DiffUsers版StableDiffusionのモデルパラメータ +NUM_TRAIN_TIMESTEPS = 1000 +BETA_START = 0.00085 +BETA_END = 0.0120 + +UNET_PARAMS_MODEL_CHANNELS = 320 +UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] +UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] +UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32` +UNET_PARAMS_IN_CHANNELS = 4 +UNET_PARAMS_OUT_CHANNELS = 4 +UNET_PARAMS_NUM_RES_BLOCKS = 2 +UNET_PARAMS_CONTEXT_DIM = 768 +UNET_PARAMS_NUM_HEADS = 8 +# UNET_PARAMS_USE_LINEAR_PROJECTION = False + +VAE_PARAMS_Z_CHANNELS = 4 +VAE_PARAMS_RESOLUTION = 256 +VAE_PARAMS_IN_CHANNELS = 3 +VAE_PARAMS_OUT_CH = 3 +VAE_PARAMS_CH = 128 +VAE_PARAMS_CH_MULT = [1, 2, 4, 4] +VAE_PARAMS_NUM_RES_BLOCKS = 2 + +# V2 +V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] +V2_UNET_PARAMS_CONTEXT_DIM = 1024 +# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True + +# Diffusersの設定を読み込むための参照モデル +DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" +DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" + + +# region StableDiffusion->Diffusersの変換コード +# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0) + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def 
renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # updated for latest diffusers + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + 
checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def linear_transformer_to_conv(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + + +def convert_ldm_unet_checkpoint(v2, checkpoint, config): + mapping = {} + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." 
in key] for layer_id in + range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in + range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in + range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [key for key in input_blocks[i] if + f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + mapping[f'input_blocks.{i}.0.op.weight'] = f"down_blocks.{block_id}.downsamplers.0.conv.weight" + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias") + mapping[f'input_blocks.{i}.0.op.bias'] = f"down_blocks.{block_id}.downsamplers.0.conv.bias" + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": 
f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) + + # オリジナル: + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + + # 
biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが + for l in output_block_list.values(): + l.sort() + + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + # SDのv2では1*1のconv2dがlinearに変わっている + # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 + if v2 and not config.get('use_linear_projection', False): + linear_transformer_to_conv(new_checkpoint) + + # print("mapping: ", json.dumps(mapping, indent=4)) + return new_checkpoint + + +# ldm key: diffusers key +vae_ldm_to_diffusers_dict = { + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.mid.attn_1.k.bias": "decoder.mid_block.attentions.0.to_k.bias", + "decoder.mid.attn_1.k.weight": "decoder.mid_block.attentions.0.to_k.weight", + "decoder.mid.attn_1.norm.bias": 
"decoder.mid_block.attentions.0.group_norm.bias", + "decoder.mid.attn_1.norm.weight": "decoder.mid_block.attentions.0.group_norm.weight", + "decoder.mid.attn_1.proj_out.bias": "decoder.mid_block.attentions.0.to_out.0.bias", + "decoder.mid.attn_1.proj_out.weight": "decoder.mid_block.attentions.0.to_out.0.weight", + "decoder.mid.attn_1.q.bias": "decoder.mid_block.attentions.0.to_q.bias", + "decoder.mid.attn_1.q.weight": "decoder.mid_block.attentions.0.to_q.weight", + "decoder.mid.attn_1.v.bias": "decoder.mid_block.attentions.0.to_v.bias", + "decoder.mid.attn_1.v.weight": "decoder.mid_block.attentions.0.to_v.weight", + "decoder.mid.block_1.conv1.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.mid.block_1.conv1.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.mid.block_1.conv2.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.mid.block_1.conv2.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.mid.block_1.norm1.bias": "decoder.mid_block.resnets.0.norm1.bias", + "decoder.mid.block_1.norm1.weight": "decoder.mid_block.resnets.0.norm1.weight", + "decoder.mid.block_1.norm2.bias": "decoder.mid_block.resnets.0.norm2.bias", + "decoder.mid.block_1.norm2.weight": "decoder.mid_block.resnets.0.norm2.weight", + "decoder.mid.block_2.conv1.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.mid.block_2.conv1.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.mid.block_2.conv2.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.mid.block_2.conv2.weight": "decoder.mid_block.resnets.1.conv2.weight", + "decoder.mid.block_2.norm1.bias": "decoder.mid_block.resnets.1.norm1.bias", + "decoder.mid.block_2.norm1.weight": "decoder.mid_block.resnets.1.norm1.weight", + "decoder.mid.block_2.norm2.bias": "decoder.mid_block.resnets.1.norm2.bias", + "decoder.mid.block_2.norm2.weight": "decoder.mid_block.resnets.1.norm2.weight", + "decoder.norm_out.bias": "decoder.conv_norm_out.bias", + "decoder.norm_out.weight": 
"decoder.conv_norm_out.weight", + "decoder.up.0.block.0.conv1.bias": "decoder.up_blocks.3.resnets.0.conv1.bias", + "decoder.up.0.block.0.conv1.weight": "decoder.up_blocks.3.resnets.0.conv1.weight", + "decoder.up.0.block.0.conv2.bias": "decoder.up_blocks.3.resnets.0.conv2.bias", + "decoder.up.0.block.0.conv2.weight": "decoder.up_blocks.3.resnets.0.conv2.weight", + "decoder.up.0.block.0.nin_shortcut.bias": "decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "decoder.up.0.block.0.nin_shortcut.weight": "decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "decoder.up.0.block.0.norm1.bias": "decoder.up_blocks.3.resnets.0.norm1.bias", + "decoder.up.0.block.0.norm1.weight": "decoder.up_blocks.3.resnets.0.norm1.weight", + "decoder.up.0.block.0.norm2.bias": "decoder.up_blocks.3.resnets.0.norm2.bias", + "decoder.up.0.block.0.norm2.weight": "decoder.up_blocks.3.resnets.0.norm2.weight", + "decoder.up.0.block.1.conv1.bias": "decoder.up_blocks.3.resnets.1.conv1.bias", + "decoder.up.0.block.1.conv1.weight": "decoder.up_blocks.3.resnets.1.conv1.weight", + "decoder.up.0.block.1.conv2.bias": "decoder.up_blocks.3.resnets.1.conv2.bias", + "decoder.up.0.block.1.conv2.weight": "decoder.up_blocks.3.resnets.1.conv2.weight", + "decoder.up.0.block.1.norm1.bias": "decoder.up_blocks.3.resnets.1.norm1.bias", + "decoder.up.0.block.1.norm1.weight": "decoder.up_blocks.3.resnets.1.norm1.weight", + "decoder.up.0.block.1.norm2.bias": "decoder.up_blocks.3.resnets.1.norm2.bias", + "decoder.up.0.block.1.norm2.weight": "decoder.up_blocks.3.resnets.1.norm2.weight", + "decoder.up.0.block.2.conv1.bias": "decoder.up_blocks.3.resnets.2.conv1.bias", + "decoder.up.0.block.2.conv1.weight": "decoder.up_blocks.3.resnets.2.conv1.weight", + "decoder.up.0.block.2.conv2.bias": "decoder.up_blocks.3.resnets.2.conv2.bias", + "decoder.up.0.block.2.conv2.weight": "decoder.up_blocks.3.resnets.2.conv2.weight", + "decoder.up.0.block.2.norm1.bias": "decoder.up_blocks.3.resnets.2.norm1.bias", + 
"decoder.up.0.block.2.norm1.weight": "decoder.up_blocks.3.resnets.2.norm1.weight", + "decoder.up.0.block.2.norm2.bias": "decoder.up_blocks.3.resnets.2.norm2.bias", + "decoder.up.0.block.2.norm2.weight": "decoder.up_blocks.3.resnets.2.norm2.weight", + "decoder.up.1.block.0.conv1.bias": "decoder.up_blocks.2.resnets.0.conv1.bias", + "decoder.up.1.block.0.conv1.weight": "decoder.up_blocks.2.resnets.0.conv1.weight", + "decoder.up.1.block.0.conv2.bias": "decoder.up_blocks.2.resnets.0.conv2.bias", + "decoder.up.1.block.0.conv2.weight": "decoder.up_blocks.2.resnets.0.conv2.weight", + "decoder.up.1.block.0.nin_shortcut.bias": "decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "decoder.up.1.block.0.nin_shortcut.weight": "decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "decoder.up.1.block.0.norm1.bias": "decoder.up_blocks.2.resnets.0.norm1.bias", + "decoder.up.1.block.0.norm1.weight": "decoder.up_blocks.2.resnets.0.norm1.weight", + "decoder.up.1.block.0.norm2.bias": "decoder.up_blocks.2.resnets.0.norm2.bias", + "decoder.up.1.block.0.norm2.weight": "decoder.up_blocks.2.resnets.0.norm2.weight", + "decoder.up.1.block.1.conv1.bias": "decoder.up_blocks.2.resnets.1.conv1.bias", + "decoder.up.1.block.1.conv1.weight": "decoder.up_blocks.2.resnets.1.conv1.weight", + "decoder.up.1.block.1.conv2.bias": "decoder.up_blocks.2.resnets.1.conv2.bias", + "decoder.up.1.block.1.conv2.weight": "decoder.up_blocks.2.resnets.1.conv2.weight", + "decoder.up.1.block.1.norm1.bias": "decoder.up_blocks.2.resnets.1.norm1.bias", + "decoder.up.1.block.1.norm1.weight": "decoder.up_blocks.2.resnets.1.norm1.weight", + "decoder.up.1.block.1.norm2.bias": "decoder.up_blocks.2.resnets.1.norm2.bias", + "decoder.up.1.block.1.norm2.weight": "decoder.up_blocks.2.resnets.1.norm2.weight", + "decoder.up.1.block.2.conv1.bias": "decoder.up_blocks.2.resnets.2.conv1.bias", + "decoder.up.1.block.2.conv1.weight": "decoder.up_blocks.2.resnets.2.conv1.weight", + "decoder.up.1.block.2.conv2.bias": 
"decoder.up_blocks.2.resnets.2.conv2.bias", + "decoder.up.1.block.2.conv2.weight": "decoder.up_blocks.2.resnets.2.conv2.weight", + "decoder.up.1.block.2.norm1.bias": "decoder.up_blocks.2.resnets.2.norm1.bias", + "decoder.up.1.block.2.norm1.weight": "decoder.up_blocks.2.resnets.2.norm1.weight", + "decoder.up.1.block.2.norm2.bias": "decoder.up_blocks.2.resnets.2.norm2.bias", + "decoder.up.1.block.2.norm2.weight": "decoder.up_blocks.2.resnets.2.norm2.weight", + "decoder.up.1.upsample.conv.bias": "decoder.up_blocks.2.upsamplers.0.conv.bias", + "decoder.up.1.upsample.conv.weight": "decoder.up_blocks.2.upsamplers.0.conv.weight", + "decoder.up.2.block.0.conv1.bias": "decoder.up_blocks.1.resnets.0.conv1.bias", + "decoder.up.2.block.0.conv1.weight": "decoder.up_blocks.1.resnets.0.conv1.weight", + "decoder.up.2.block.0.conv2.bias": "decoder.up_blocks.1.resnets.0.conv2.bias", + "decoder.up.2.block.0.conv2.weight": "decoder.up_blocks.1.resnets.0.conv2.weight", + "decoder.up.2.block.0.norm1.bias": "decoder.up_blocks.1.resnets.0.norm1.bias", + "decoder.up.2.block.0.norm1.weight": "decoder.up_blocks.1.resnets.0.norm1.weight", + "decoder.up.2.block.0.norm2.bias": "decoder.up_blocks.1.resnets.0.norm2.bias", + "decoder.up.2.block.0.norm2.weight": "decoder.up_blocks.1.resnets.0.norm2.weight", + "decoder.up.2.block.1.conv1.bias": "decoder.up_blocks.1.resnets.1.conv1.bias", + "decoder.up.2.block.1.conv1.weight": "decoder.up_blocks.1.resnets.1.conv1.weight", + "decoder.up.2.block.1.conv2.bias": "decoder.up_blocks.1.resnets.1.conv2.bias", + "decoder.up.2.block.1.conv2.weight": "decoder.up_blocks.1.resnets.1.conv2.weight", + "decoder.up.2.block.1.norm1.bias": "decoder.up_blocks.1.resnets.1.norm1.bias", + "decoder.up.2.block.1.norm1.weight": "decoder.up_blocks.1.resnets.1.norm1.weight", + "decoder.up.2.block.1.norm2.bias": "decoder.up_blocks.1.resnets.1.norm2.bias", + "decoder.up.2.block.1.norm2.weight": "decoder.up_blocks.1.resnets.1.norm2.weight", + "decoder.up.2.block.2.conv1.bias": 
"decoder.up_blocks.1.resnets.2.conv1.bias", + "decoder.up.2.block.2.conv1.weight": "decoder.up_blocks.1.resnets.2.conv1.weight", + "decoder.up.2.block.2.conv2.bias": "decoder.up_blocks.1.resnets.2.conv2.bias", + "decoder.up.2.block.2.conv2.weight": "decoder.up_blocks.1.resnets.2.conv2.weight", + "decoder.up.2.block.2.norm1.bias": "decoder.up_blocks.1.resnets.2.norm1.bias", + "decoder.up.2.block.2.norm1.weight": "decoder.up_blocks.1.resnets.2.norm1.weight", + "decoder.up.2.block.2.norm2.bias": "decoder.up_blocks.1.resnets.2.norm2.bias", + "decoder.up.2.block.2.norm2.weight": "decoder.up_blocks.1.resnets.2.norm2.weight", + "decoder.up.2.upsample.conv.bias": "decoder.up_blocks.1.upsamplers.0.conv.bias", + "decoder.up.2.upsample.conv.weight": "decoder.up_blocks.1.upsamplers.0.conv.weight", + "decoder.up.3.block.0.conv1.bias": "decoder.up_blocks.0.resnets.0.conv1.bias", + "decoder.up.3.block.0.conv1.weight": "decoder.up_blocks.0.resnets.0.conv1.weight", + "decoder.up.3.block.0.conv2.bias": "decoder.up_blocks.0.resnets.0.conv2.bias", + "decoder.up.3.block.0.conv2.weight": "decoder.up_blocks.0.resnets.0.conv2.weight", + "decoder.up.3.block.0.norm1.bias": "decoder.up_blocks.0.resnets.0.norm1.bias", + "decoder.up.3.block.0.norm1.weight": "decoder.up_blocks.0.resnets.0.norm1.weight", + "decoder.up.3.block.0.norm2.bias": "decoder.up_blocks.0.resnets.0.norm2.bias", + "decoder.up.3.block.0.norm2.weight": "decoder.up_blocks.0.resnets.0.norm2.weight", + "decoder.up.3.block.1.conv1.bias": "decoder.up_blocks.0.resnets.1.conv1.bias", + "decoder.up.3.block.1.conv1.weight": "decoder.up_blocks.0.resnets.1.conv1.weight", + "decoder.up.3.block.1.conv2.bias": "decoder.up_blocks.0.resnets.1.conv2.bias", + "decoder.up.3.block.1.conv2.weight": "decoder.up_blocks.0.resnets.1.conv2.weight", + "decoder.up.3.block.1.norm1.bias": "decoder.up_blocks.0.resnets.1.norm1.bias", + "decoder.up.3.block.1.norm1.weight": "decoder.up_blocks.0.resnets.1.norm1.weight", + "decoder.up.3.block.1.norm2.bias": 
"decoder.up_blocks.0.resnets.1.norm2.bias", + "decoder.up.3.block.1.norm2.weight": "decoder.up_blocks.0.resnets.1.norm2.weight", + "decoder.up.3.block.2.conv1.bias": "decoder.up_blocks.0.resnets.2.conv1.bias", + "decoder.up.3.block.2.conv1.weight": "decoder.up_blocks.0.resnets.2.conv1.weight", + "decoder.up.3.block.2.conv2.bias": "decoder.up_blocks.0.resnets.2.conv2.bias", + "decoder.up.3.block.2.conv2.weight": "decoder.up_blocks.0.resnets.2.conv2.weight", + "decoder.up.3.block.2.norm1.bias": "decoder.up_blocks.0.resnets.2.norm1.bias", + "decoder.up.3.block.2.norm1.weight": "decoder.up_blocks.0.resnets.2.norm1.weight", + "decoder.up.3.block.2.norm2.bias": "decoder.up_blocks.0.resnets.2.norm2.bias", + "decoder.up.3.block.2.norm2.weight": "decoder.up_blocks.0.resnets.2.norm2.weight", + "decoder.up.3.upsample.conv.bias": "decoder.up_blocks.0.upsamplers.0.conv.bias", + "decoder.up.3.upsample.conv.weight": "decoder.up_blocks.0.upsamplers.0.conv.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.down.0.block.0.conv1.bias": "encoder.down_blocks.0.resnets.0.conv1.bias", + "encoder.down.0.block.0.conv1.weight": "encoder.down_blocks.0.resnets.0.conv1.weight", + "encoder.down.0.block.0.conv2.bias": "encoder.down_blocks.0.resnets.0.conv2.bias", + "encoder.down.0.block.0.conv2.weight": "encoder.down_blocks.0.resnets.0.conv2.weight", + "encoder.down.0.block.0.norm1.bias": "encoder.down_blocks.0.resnets.0.norm1.bias", + "encoder.down.0.block.0.norm1.weight": "encoder.down_blocks.0.resnets.0.norm1.weight", + "encoder.down.0.block.0.norm2.bias": "encoder.down_blocks.0.resnets.0.norm2.bias", + "encoder.down.0.block.0.norm2.weight": "encoder.down_blocks.0.resnets.0.norm2.weight", + "encoder.down.0.block.1.conv1.bias": "encoder.down_blocks.0.resnets.1.conv1.bias", + "encoder.down.0.block.1.conv1.weight": 
"encoder.down_blocks.0.resnets.1.conv1.weight", + "encoder.down.0.block.1.conv2.bias": "encoder.down_blocks.0.resnets.1.conv2.bias", + "encoder.down.0.block.1.conv2.weight": "encoder.down_blocks.0.resnets.1.conv2.weight", + "encoder.down.0.block.1.norm1.bias": "encoder.down_blocks.0.resnets.1.norm1.bias", + "encoder.down.0.block.1.norm1.weight": "encoder.down_blocks.0.resnets.1.norm1.weight", + "encoder.down.0.block.1.norm2.bias": "encoder.down_blocks.0.resnets.1.norm2.bias", + "encoder.down.0.block.1.norm2.weight": "encoder.down_blocks.0.resnets.1.norm2.weight", + "encoder.down.0.downsample.conv.bias": "encoder.down_blocks.0.downsamplers.0.conv.bias", + "encoder.down.0.downsample.conv.weight": "encoder.down_blocks.0.downsamplers.0.conv.weight", + "encoder.down.1.block.0.conv1.bias": "encoder.down_blocks.1.resnets.0.conv1.bias", + "encoder.down.1.block.0.conv1.weight": "encoder.down_blocks.1.resnets.0.conv1.weight", + "encoder.down.1.block.0.conv2.bias": "encoder.down_blocks.1.resnets.0.conv2.bias", + "encoder.down.1.block.0.conv2.weight": "encoder.down_blocks.1.resnets.0.conv2.weight", + "encoder.down.1.block.0.nin_shortcut.bias": "encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "encoder.down.1.block.0.nin_shortcut.weight": "encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "encoder.down.1.block.0.norm1.bias": "encoder.down_blocks.1.resnets.0.norm1.bias", + "encoder.down.1.block.0.norm1.weight": "encoder.down_blocks.1.resnets.0.norm1.weight", + "encoder.down.1.block.0.norm2.bias": "encoder.down_blocks.1.resnets.0.norm2.bias", + "encoder.down.1.block.0.norm2.weight": "encoder.down_blocks.1.resnets.0.norm2.weight", + "encoder.down.1.block.1.conv1.bias": "encoder.down_blocks.1.resnets.1.conv1.bias", + "encoder.down.1.block.1.conv1.weight": "encoder.down_blocks.1.resnets.1.conv1.weight", + "encoder.down.1.block.1.conv2.bias": "encoder.down_blocks.1.resnets.1.conv2.bias", + "encoder.down.1.block.1.conv2.weight": 
"encoder.down_blocks.1.resnets.1.conv2.weight", + "encoder.down.1.block.1.norm1.bias": "encoder.down_blocks.1.resnets.1.norm1.bias", + "encoder.down.1.block.1.norm1.weight": "encoder.down_blocks.1.resnets.1.norm1.weight", + "encoder.down.1.block.1.norm2.bias": "encoder.down_blocks.1.resnets.1.norm2.bias", + "encoder.down.1.block.1.norm2.weight": "encoder.down_blocks.1.resnets.1.norm2.weight", + "encoder.down.1.downsample.conv.bias": "encoder.down_blocks.1.downsamplers.0.conv.bias", + "encoder.down.1.downsample.conv.weight": "encoder.down_blocks.1.downsamplers.0.conv.weight", + "encoder.down.2.block.0.conv1.bias": "encoder.down_blocks.2.resnets.0.conv1.bias", + "encoder.down.2.block.0.conv1.weight": "encoder.down_blocks.2.resnets.0.conv1.weight", + "encoder.down.2.block.0.conv2.bias": "encoder.down_blocks.2.resnets.0.conv2.bias", + "encoder.down.2.block.0.conv2.weight": "encoder.down_blocks.2.resnets.0.conv2.weight", + "encoder.down.2.block.0.nin_shortcut.bias": "encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "encoder.down.2.block.0.nin_shortcut.weight": "encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "encoder.down.2.block.0.norm1.bias": "encoder.down_blocks.2.resnets.0.norm1.bias", + "encoder.down.2.block.0.norm1.weight": "encoder.down_blocks.2.resnets.0.norm1.weight", + "encoder.down.2.block.0.norm2.bias": "encoder.down_blocks.2.resnets.0.norm2.bias", + "encoder.down.2.block.0.norm2.weight": "encoder.down_blocks.2.resnets.0.norm2.weight", + "encoder.down.2.block.1.conv1.bias": "encoder.down_blocks.2.resnets.1.conv1.bias", + "encoder.down.2.block.1.conv1.weight": "encoder.down_blocks.2.resnets.1.conv1.weight", + "encoder.down.2.block.1.conv2.bias": "encoder.down_blocks.2.resnets.1.conv2.bias", + "encoder.down.2.block.1.conv2.weight": "encoder.down_blocks.2.resnets.1.conv2.weight", + "encoder.down.2.block.1.norm1.bias": "encoder.down_blocks.2.resnets.1.norm1.bias", + "encoder.down.2.block.1.norm1.weight": 
"encoder.down_blocks.2.resnets.1.norm1.weight", + "encoder.down.2.block.1.norm2.bias": "encoder.down_blocks.2.resnets.1.norm2.bias", + "encoder.down.2.block.1.norm2.weight": "encoder.down_blocks.2.resnets.1.norm2.weight", + "encoder.down.2.downsample.conv.bias": "encoder.down_blocks.2.downsamplers.0.conv.bias", + "encoder.down.2.downsample.conv.weight": "encoder.down_blocks.2.downsamplers.0.conv.weight", + "encoder.down.3.block.0.conv1.bias": "encoder.down_blocks.3.resnets.0.conv1.bias", + "encoder.down.3.block.0.conv1.weight": "encoder.down_blocks.3.resnets.0.conv1.weight", + "encoder.down.3.block.0.conv2.bias": "encoder.down_blocks.3.resnets.0.conv2.bias", + "encoder.down.3.block.0.conv2.weight": "encoder.down_blocks.3.resnets.0.conv2.weight", + "encoder.down.3.block.0.norm1.bias": "encoder.down_blocks.3.resnets.0.norm1.bias", + "encoder.down.3.block.0.norm1.weight": "encoder.down_blocks.3.resnets.0.norm1.weight", + "encoder.down.3.block.0.norm2.bias": "encoder.down_blocks.3.resnets.0.norm2.bias", + "encoder.down.3.block.0.norm2.weight": "encoder.down_blocks.3.resnets.0.norm2.weight", + "encoder.down.3.block.1.conv1.bias": "encoder.down_blocks.3.resnets.1.conv1.bias", + "encoder.down.3.block.1.conv1.weight": "encoder.down_blocks.3.resnets.1.conv1.weight", + "encoder.down.3.block.1.conv2.bias": "encoder.down_blocks.3.resnets.1.conv2.bias", + "encoder.down.3.block.1.conv2.weight": "encoder.down_blocks.3.resnets.1.conv2.weight", + "encoder.down.3.block.1.norm1.bias": "encoder.down_blocks.3.resnets.1.norm1.bias", + "encoder.down.3.block.1.norm1.weight": "encoder.down_blocks.3.resnets.1.norm1.weight", + "encoder.down.3.block.1.norm2.bias": "encoder.down_blocks.3.resnets.1.norm2.bias", + "encoder.down.3.block.1.norm2.weight": "encoder.down_blocks.3.resnets.1.norm2.weight", + "encoder.mid.attn_1.k.bias": "encoder.mid_block.attentions.0.to_k.bias", + "encoder.mid.attn_1.k.weight": "encoder.mid_block.attentions.0.to_k.weight", + "encoder.mid.attn_1.norm.bias": 
"encoder.mid_block.attentions.0.group_norm.bias", + "encoder.mid.attn_1.norm.weight": "encoder.mid_block.attentions.0.group_norm.weight", + "encoder.mid.attn_1.proj_out.bias": "encoder.mid_block.attentions.0.to_out.0.bias", + "encoder.mid.attn_1.proj_out.weight": "encoder.mid_block.attentions.0.to_out.0.weight", + "encoder.mid.attn_1.q.bias": "encoder.mid_block.attentions.0.to_q.bias", + "encoder.mid.attn_1.q.weight": "encoder.mid_block.attentions.0.to_q.weight", + "encoder.mid.attn_1.v.bias": "encoder.mid_block.attentions.0.to_v.bias", + "encoder.mid.attn_1.v.weight": "encoder.mid_block.attentions.0.to_v.weight", + "encoder.mid.block_1.conv1.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.mid.block_1.conv1.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.mid.block_1.conv2.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.mid.block_1.conv2.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.mid.block_1.norm1.bias": "encoder.mid_block.resnets.0.norm1.bias", + "encoder.mid.block_1.norm1.weight": "encoder.mid_block.resnets.0.norm1.weight", + "encoder.mid.block_1.norm2.bias": "encoder.mid_block.resnets.0.norm2.bias", + "encoder.mid.block_1.norm2.weight": "encoder.mid_block.resnets.0.norm2.weight", + "encoder.mid.block_2.conv1.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.mid.block_2.conv1.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.mid.block_2.conv2.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.mid.block_2.conv2.weight": "encoder.mid_block.resnets.1.conv2.weight", + "encoder.mid.block_2.norm1.bias": "encoder.mid_block.resnets.1.norm1.bias", + "encoder.mid.block_2.norm1.weight": "encoder.mid_block.resnets.1.norm1.weight", + "encoder.mid.block_2.norm2.bias": "encoder.mid_block.resnets.1.norm2.bias", + "encoder.mid.block_2.norm2.weight": "encoder.mid_block.resnets.1.norm2.weight", + "encoder.norm_out.bias": "encoder.conv_norm_out.bias", + "encoder.norm_out.weight": 
"encoder.conv_norm_out.weight", + "post_quant_conv.bias": "post_quant_conv.bias", + "post_quant_conv.weight": "post_quant_conv.weight", + "quant_conv.bias": "quant_conv.bias", + "quant_conv.weight": "quant_conv.weight" +} + + +def get_diffusers_vae_key_from_ldm_key(target_ldm_key, i=None): + for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items(): + if i is not None: + ldm_key = ldm_key.replace("{i}", str(i)) + diffusers_key = diffusers_key.replace("{i}", str(i)) + if ldm_key == target_ldm_key: + return diffusers_key + + if ldm_key in vae_ldm_to_diffusers_dict: + return vae_ldm_to_diffusers_dict[ldm_key] + else: + return None + +# def get_ldm_vae_key_from_diffusers_key(target_diffusers_key): +# for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items(): +# if diffusers_key == target_diffusers_key: +# return ldm_key +# return None + +def get_ldm_vae_key_from_diffusers_key(target_diffusers_key): + for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items(): + if "{" in diffusers_key: # if we have a placeholder + # escape special characters in the key, and replace the placeholder with a regex group + pattern = re.escape(diffusers_key).replace("\\{i\\}", "(\\d+)") + match = re.match(pattern, target_diffusers_key) + if match: # if we found a match + return ldm_key.format(i=match.group(1)) + elif diffusers_key == target_diffusers_key: + return ldm_key + return None + + +vae_keys_squished_on_diffusers = [ + "decoder.mid_block.attentions.0.to_k.weight", + "decoder.mid_block.attentions.0.to_out.0.weight", + "decoder.mid_block.attentions.0.to_q.weight", + "decoder.mid_block.attentions.0.to_v.weight", + "encoder.mid_block.attentions.0.to_k.weight", + "encoder.mid_block.attentions.0.to_out.0.weight", + "encoder.mid_block.attentions.0.to_q.weight", + "encoder.mid_block.attentions.0.to_v.weight" +] + +def convert_diffusers_back_to_ldm(diffusers_vae): + new_state_dict = OrderedDict() + diffusers_state_dict = diffusers_vae.state_dict() + for key, value in 
diffusers_state_dict.items(): + val_to_save = value + if key in vae_keys_squished_on_diffusers: + val_to_save = value.clone() + # (512, 512) diffusers and (512, 512, 1, 1) ldm + val_to_save = val_to_save.unsqueeze(-1).unsqueeze(-1) + ldm_key = get_ldm_vae_key_from_diffusers_key(key) + if ldm_key is not None: + new_state_dict[ldm_key] = val_to_save + else: + # for now add current key + new_state_dict[key] = val_to_save + return new_state_dict + + +def convert_ldm_vae_checkpoint(checkpoint, config): + mapping = {} + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint + + new_checkpoint = {} + + # for key in list(vae_state_dict.keys()): + # diffusers_key = get_diffusers_vae_key_from_ldm_key(key) + # if diffusers_key is not None: + # new_checkpoint[diffusers_key] = vae_state_dict[key] + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + 
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in + range(num_down_blocks)} + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in + range(num_up_blocks)} + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + mapping[f"encoder.down.{i}.downsample.conv.weight"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + mapping[f"encoder.down.{i}.downsample.conv.bias"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, 
additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [key for key in up_blocks[block_id] if + f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + mapping[f"decoder.up.{block_id}.upsample.conv.weight"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + mapping[f"decoder.up.{block_id}.upsample.conv.bias"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 
1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # unet_params = original_config.model.params.unet_config.params + + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 
else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION, + ) + if v2 and use_linear_projection_in_v2: + config["use_linear_projection"] = True + + return config + + +def create_vae_diffusers_config(): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config + + +def convert_ldm_clip_checkpoint_v1(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] + # support checkpoint without position_ids (invalid checkpoint) + if "text_model.embeddings.position_ids" not in text_model_dict: + text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text + return text_model_dict + + +def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): + # 嫌になるくらい違うぞ! 
+ def convert_key(key): + if not key.startswith("cond_stage_model"): + return None + + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = None # 使われない??? + elif ".logit_scale" in key: + key = None # 使われない??? + elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if ".resblocks.23." in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if ".resblocks.23." 
in key: + continue + if ".resblocks" in key and ".attn.in_proj_" in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids + return new_sd + + +# endregion + + +# region Diffusers->StableDiffusion の変換コード +# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) + + +def conv_transformer_to_linear(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + + +def convert_unet_state_dict_to_sd(v2, unet_state_dict): + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", 
"conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3 * i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." 
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2 * j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] + + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3 - i}.upsample." 
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3 - i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i + 1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] + + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + + return new_state_dict + + +# endregion + +# region 自作のモデル読み書きなど + + +def is_safetensors(path): + return os.path.splitext(path)[1].lower() == ".safetensors" + + +def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + 
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), + ] + + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path) # , device) # may causes error + else: + checkpoint = torch.load(ckpt_path, map_location=device) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None + + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from):] + key_reps.append((key, new_key)) + + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + return checkpoint, state_dict + + +# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, + unet_use_linear_projection_in_v2=False): + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2) + converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) + + unet = UNet2DConditionModel(**unet_config).to(device) + info = unet.load_state_dict(converted_unet_checkpoint) + print("loading u-net:", info) + + # Convert the VAE model. 
+ vae_config = create_vae_diffusers_config() + converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) + + vae = AutoencoderKL(**vae_config).to(device) + info = vae.load_state_dict(converted_vae_checkpoint) + print("loading vae:", info) + + # convert text_model + if v2: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=23, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=512, + torch_dtype="float32", + transformers_version="4.25.0.dev0", + ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + else: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) + + logging.set_verbosity_error() # don't show annoying warning + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) + logging.set_verbosity_warning() + + # latest transformers doesnt have position ids. Do we remove it? 
+ if "text_model.embeddings.position_ids" not in text_model.state_dict(): + del converted_text_encoder_checkpoint["text_model.embeddings.position_ids"] + + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + print("loading text encoder:", info) + + return text_model, vae, unet + + +def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = 
new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + # 最後の層などを捏造するか + if make_dummy_weights: + print("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) + for key in keys: + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + + # Diffusersに含まれない重みを作っておく + new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd["logit_scale"] = torch.tensor(1) + + return new_sd + + +def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, + vae=None): + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False + else: + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] + else: + # 新しく作る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) + + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) + update_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + 
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None,
                              use_safetensors=False):
    """Assemble a StableDiffusionPipeline from the given parts and write it to ``output_dir``.

    The scheduler and tokenizer (and the VAE, when ``vae`` is None) are loaded from
    ``pretrained_model_name_or_path``; when that is None the v1/v2 reference model id
    is used instead, selected by ``v2``.
    """
    reference_model = pretrained_model_name_or_path
    if reference_model is None:
        # fall back to the default settings for v1/v2
        reference_model = DIFFUSERS_REF_MODEL_ID_V2 if v2 else DIFFUSERS_REF_MODEL_ID_V1

    # same load order as before: scheduler, tokenizer, then (optionally) the VAE
    loaded_scheduler = DDIMScheduler.from_pretrained(reference_model, subfolder="scheduler")
    loaded_tokenizer = CLIPTokenizer.from_pretrained(reference_model, subfolder="tokenizer")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(reference_model, subfolder="vae")

    components = dict(
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=loaded_scheduler,
        tokenizer=loaded_tokenizer,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=None,
    )
    StableDiffusionPipeline(**components).save_pretrained(output_dir, safe_serialization=use_safetensors)
+ + +def load_vae(vae_id, dtype): + print(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + print(f"exception occurs in loading vae: {e}") + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + return vae + + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu") + vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. 
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
    """Enumerate (width, height) bucket resolutions within the area of ``max_reso``.

    Args:
        max_reso: ``(max_width, max_height)``; its area, measured in
            ``divisible``-sized cells, is the budget for every bucket.
        min_size: smallest side length that is stepped through.
        max_size: largest side length; the derived side is capped to this value.
        divisible: step between candidate side lengths and the cell size.

    Returns:
        A sorted list of unique ``(width, height)`` tuples. Each candidate is added
        in both orientations, plus one square bucket.
    """
    max_width, max_height = max_reso
    area_budget = (max_width // divisible) * (max_height // divisible)

    # the square bucket that fits the budget
    square_side = int(math.sqrt(area_budget)) * divisible
    buckets = {(square_side, square_side)}

    side = min_size
    while side <= max_size:
        # largest other side that keeps the bucket inside the budget, capped at max_size
        other = min(max_size, (area_budget // (side // divisible)) * divisible)
        buckets.update(((side, other), (other, side)))
        side += divisible

    return sorted(buckets)
class CheckpointGradients(nn.Module):
    """Run a module through ``torch.utils.checkpoint`` when gradient checkpointing is on.

    With ``is_gradient_checkpointing`` False the wrapped module is simply called.
    """

    def __init__(self, is_gradient_checkpointing=True):
        super(CheckpointGradients, self).__init__()
        self.is_gradient_checkpointing = is_gradient_checkpointing

    def forward(self, module, *args, num_chunks=1):
        """Call ``module(*args)``, recomputing activations in backward when enabled.

        Args:
            module: callable to run.
            *args: positional inputs forwarded to ``module``.
            num_chunks: accepted for backward compatibility and ignored.
                ``torch.utils.checkpoint.checkpoint`` has no such option; the previous
                code read it from a ``self.num_chunks`` attribute that was never set,
                so the checkpointing branch always raised AttributeError.

        Returns:
            Whatever ``module(*args)`` returns.
        """
        if not self.is_gradient_checkpointing:
            return module(*args)
        # use_reentrant is passed explicitly, as current PyTorch asks for; the
        # non-reentrant variant is the one the PyTorch docs recommend
        return checkpoint(module, *args, use_reentrant=False)
class LosslessLatentDecoder(nn.Module):
    """Fixed-kernel transposed convolution that unpacks channels back into pixels.

    Every group of ``latent_depth * latent_depth`` input channels is scattered into one
    output channel, each input channel landing on its own position of a
    ``latent_depth x latent_depth`` pixel cluster.
    """

    def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
        """
        Args:
            in_channels: number of latent channels fed to ``forward``.
            latent_depth: side of the pixel cluster; also used as the stride.
            dtype: dtype the kernel is created in.
            trainable: when True the kernel is an ``nn.Parameter``; otherwise it is a
                plain tensor attribute (not a parameter and not a registered buffer).
        """
        super(LosslessLatentDecoder, self).__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.latent_depth = latent_depth
        self.in_channels = in_channels
        self.out_channels = int(in_channels // (latent_depth * latent_depth))
        numpy_kernel = self.build_kernel(in_channels, latent_depth)
        kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
        if trainable:
            self.kernel = nn.Parameter(kernel)
        else:
            self.kernel = kernel

    def build_kernel(self, in_channels, latent_depth):
        """Return the float32 numpy kernel, shaped (in_channels, out_channels, depth, depth)."""
        # ported from older tensorflow code.
        # tensorflow kernel is (height, width, out_channels, in_channels)
        # pytorch kernel is (in_channels, out_channels, height, width)
        out_channels = self.out_channels

        kernel_shape = [in_channels, out_channels, latent_depth, latent_depth]
        kernel = np.zeros(kernel_shape, np.float32)

        # each pixel of a cluster is taken from a separate input channel
        for c in range(0, out_channels):
            for i, (x, y) in enumerate(itertools.product(range(latent_depth), repeat=2)):
                kernel[c * latent_depth * latent_depth + i, c, y, x] = 1.0

        return kernel

    def forward(self, x):
        """Deconvolve ``x`` with the kernel, matching the kernel to ``x``'s dtype and device."""
        kernel = self.kernel
        if kernel.dtype != x.dtype or kernel.device != x.device:
            kernel = kernel.to(device=x.device, dtype=x.dtype)
            if not isinstance(self.kernel, nn.Parameter):
                # plain tensor: keep the converted copy so the next call skips the cast
                self.kernel = kernel
            # a Parameter is only cast locally; assigning a plain tensor over a
            # registered parameter raises TypeError and would drop it from training

        # Deconvolve input tensor with the kernel
        return nn.functional.conv_transpose2d(x, kernel, stride=self.latent_depth, padding=0, groups=1)
class LosslessLatentVAE(nn.Module):
    """Pairs a LosslessLatentEncoder with the matching LosslessLatentDecoder."""

    def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
        """
        Args:
            in_channels: channel count of the input passed to the encoder.
            latent_depth: forwarded to both the encoder and the decoder.
            dtype: dtype forwarded to both sub-modules.
            trainable: forwarded to both sub-modules.
        """
        super(LosslessLatentVAE, self).__init__()
        self.latent_depth = latent_depth
        self.in_channels = in_channels
        self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable)
        # the decoder consumes exactly what the encoder produces
        encoder_out_channels = self.encoder.out_channels
        self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable)

    def forward(self, x):
        """Encode then decode ``x`` and return the result."""
        # the sub-modules are stored as ``encoder`` / ``decoder``; the previous
        # ``latent_encoder`` / ``latent_decoder`` names were never defined
        latent = self.encoder(x)
        return self.decoder(latent)

    def encode(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def decode(self, x):
        """Return the decoder output for ``x``."""
        return self.decoder(x)
class EmptyLogger:
    """Base logger that does nothing; a placeholder exposing the logger interface."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any arguments."""

    def start(self):
        """Start logging the training (no-op)."""
        return None

    def log(self, *args, **kwargs):
        """Collect the log to send (no-op)."""
        return None

    def commit(self, step: Optional[int] = None):
        """Send the log (no-op)."""
        return None

    def log_image(self, *args, **kwargs):
        """Log an image (no-op)."""
        return None

    def finish(self):
        """Finish logging (no-op)."""
        return None
class UILogger:
    """Buffered metric logger that writes steps and key/value metrics to a SQLite file.

    ``log`` accumulates values for the current step, ``commit`` moves them into an
    in-memory buffer, and the buffer is written to the database once it holds
    ``flush_every_n`` metric rows or ``flush_every_secs`` seconds have passed since
    the last flush. ``finish`` flushes what is left and closes the connection.
    """

    def __init__(
        self,
        log_file: str,
        flush_every_n: int = 256,
        flush_every_secs: float = 0.25,
    ) -> None:
        """
        Args:
            log_file: path of the SQLite database file.
            flush_every_n: flush once this many metric rows are buffered.
            flush_every_secs: flush once this many seconds passed since the last flush.
        """
        self.log_file = log_file
        self._log_to_commit: Dict[str, Any] = {}

        self._con: Optional[sqlite3.Connection] = None
        self._started = False

        self._step_counter = 0

        # buffered writes
        self._pending_steps: List[Tuple[int, float]] = []
        self._pending_metrics: List[
            Tuple[int, str, Optional[float], Optional[str]]
        ] = []
        self._pending_key_minmax: Dict[str, Tuple[int, int]] = {}

        self._flush_every_n = int(flush_every_n)
        self._flush_every_secs = float(flush_every_secs)
        self._last_flush = time.time()

        self._first_commit_done = False

    # start logging the training
    def start(self):
        """Open the database (creating parent directories and the schema) if not started."""
        if self._started:
            return

        parent = os.path.dirname(os.path.abspath(self.log_file))
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

        # isolation_level=None: transactions are controlled with explicit BEGIN/COMMIT
        con = sqlite3.connect(self.log_file, timeout=30.0, isolation_level=None)
        try:
            con.execute("PRAGMA journal_mode=WAL;")
            con.execute("PRAGMA synchronous=NORMAL;")
            con.execute("PRAGMA temp_store=MEMORY;")
            # required for the ON DELETE CASCADE that _prune_future_steps relies on
            con.execute("PRAGMA foreign_keys=ON;")
            con.execute("PRAGMA busy_timeout=30000;")

            self._init_schema(con)
        except BaseException:
            # do not leak a half-initialised connection
            con.close()
            raise

        self._con = con
        self._started = True
        self._last_flush = time.time()

    # collect the log to send
    def log(self, log_dict):
        """Merge ``log_dict`` (e.g. ``{'learning_rate': lr}``) into the pending step.

        Raises:
            TypeError: if ``log_dict`` is not a dict.
        """
        if not isinstance(log_dict, dict):
            raise TypeError("log_dict must be a dict")
        self._log_to_commit.update(log_dict)

    # send the log
    def commit(self, step: Optional[int] = None):
        """Buffer everything collected by ``log`` under ``step`` and flush when due.

        Args:
            step: step number for the collected values. When None an internal counter
                is used; an explicit step at or beyond that counter advances it.
        """
        if not self._started:
            self.start()

        if not self._log_to_commit:
            return

        if step is None:
            step = self._step_counter
            self._step_counter += 1
        else:
            step = int(step)
            if step >= self._step_counter:
                self._step_counter = step + 1

        # On the first commit of this run, prune any rows from a prior run
        # whose step is greater than where we are resuming from.
        if not self._first_commit_done:
            self._prune_future_steps(step)
            self._first_commit_done = True

        wall_time = time.time()

        # buffer step row (upsert later)
        self._pending_steps.append((step, wall_time))

        # buffer metrics rows + key min/max updates
        for k, v in self._log_to_commit.items():
            k = k if isinstance(k, str) else str(k)
            vr, vt = self._coerce_value(v)

            self._pending_metrics.append((step, k, vr, vt))

            if k in self._pending_key_minmax:
                lo, hi = self._pending_key_minmax[k]
                self._pending_key_minmax[k] = (min(lo, step), max(hi, step))
            else:
                self._pending_key_minmax[k] = (step, step)

        self._log_to_commit = {}

        # flush conditions
        now = time.time()
        if (
            len(self._pending_metrics) >= self._flush_every_n
            or (now - self._last_flush) >= self._flush_every_secs
        ):
            self._flush()

    # log image
    def log_image(self, *args, **kwargs):
        """Images are not logged by this logger for now."""
        pass

    # finish logging
    def finish(self):
        """Flush buffered rows and close the database connection.

        The connection is closed even when the final flush raises; the exception
        still propagates to the caller.
        """
        if not self._started:
            return

        try:
            self._flush()
        finally:
            con = self._con
            self._con = None
            self._started = False
            if con is not None:
                con.close()

    # -------------------------
    # internal
    # -------------------------

    @staticmethod
    def _rollback_quietly(con: sqlite3.Connection) -> None:
        """Undo the open transaction, if any, without masking the original error."""
        try:
            if con.in_transaction:
                con.execute("ROLLBACK;")
        except sqlite3.Error:
            pass

    def _init_schema(self, con: sqlite3.Connection) -> None:
        """Create the ``steps``, ``metric_keys`` and ``metrics`` tables and the index."""
        con.execute("BEGIN;")
        try:
            con.execute("""
                CREATE TABLE IF NOT EXISTS steps (
                    step INTEGER PRIMARY KEY,
                    wall_time REAL NOT NULL
                );
            """)

            con.execute("""
                CREATE TABLE IF NOT EXISTS metric_keys (
                    key TEXT PRIMARY KEY,
                    first_seen_step INTEGER,
                    last_seen_step INTEGER
                );
            """)

            con.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    step INTEGER NOT NULL,
                    key TEXT NOT NULL,
                    value_real REAL,
                    value_text TEXT,
                    PRIMARY KEY (step, key),
                    FOREIGN KEY (step) REFERENCES steps(step) ON DELETE CASCADE
                );
            """)

            con.execute(
                "CREATE INDEX IF NOT EXISTS idx_metrics_key_step ON metrics (key, step);"
            )

            con.execute("COMMIT;")
        except BaseException:
            self._rollback_quietly(con)
            raise

    def _coerce_value(self, v: Any) -> Tuple[Optional[float], Optional[str]]:
        """Split ``v`` into ``(value_real, value_text)``.

        Anything ``float()`` accepts is stored as a real; everything else is stored
        as ``str(v)``. ``None`` yields ``(None, None)``.
        """
        if v is None:
            return None, None
        if isinstance(v, bool):
            return float(int(v)), None
        if isinstance(v, (int, float)):
            return float(v), None
        try:
            return float(v), None  # type: ignore[arg-type]
        except Exception:
            return None, str(v)

    def _prune_future_steps(self, current_step: int) -> None:
        """Delete rows whose step is greater than ``current_step`` and fix up key ranges."""
        assert self._con is not None
        con = self._con

        con.execute("BEGIN;")
        try:
            # metrics rows cascade via FK ON DELETE CASCADE
            con.execute("DELETE FROM steps WHERE step > ?;", (current_step,))
            # drop any keys that no longer have any metrics, and clamp last_seen_step
            con.execute(
                "DELETE FROM metric_keys "
                "WHERE NOT EXISTS (SELECT 1 FROM metrics WHERE metrics.key = metric_keys.key);"
            )
            con.execute(
                "UPDATE metric_keys "
                "SET last_seen_step = (SELECT MAX(step) FROM metrics WHERE metrics.key = metric_keys.key) "
                "WHERE last_seen_step > ?;",
                (current_step,),
            )
            con.execute("COMMIT;")
        except BaseException:
            self._rollback_quietly(con)
            raise

    def _flush(self) -> None:
        """Write all buffered rows in one transaction and clear the buffers.

        On failure the transaction is rolled back and the buffers are kept, so the
        connection stays usable and a later flush can retry.
        """
        if not self._pending_steps and not self._pending_metrics:
            return

        assert self._con is not None
        con = self._con

        con.execute("BEGIN;")
        try:
            # steps upsert
            if self._pending_steps:
                con.executemany(
                    "INSERT INTO steps(step, wall_time) VALUES(?, ?) "
                    "ON CONFLICT(step) DO UPDATE SET wall_time=excluded.wall_time;",
                    self._pending_steps,
                )

            # keys table upsert (maintains list of keys + seen range)
            if self._pending_key_minmax:
                con.executemany(
                    "INSERT INTO metric_keys(key, first_seen_step, last_seen_step) VALUES(?, ?, ?) "
                    "ON CONFLICT(key) DO UPDATE SET "
                    "first_seen_step=MIN(metric_keys.first_seen_step, excluded.first_seen_step), "
                    "last_seen_step=MAX(metric_keys.last_seen_step, excluded.last_seen_step);",
                    [(k, lo, hi) for k, (lo, hi) in self._pending_key_minmax.items()],
                )

            # metrics upsert
            if self._pending_metrics:
                con.executemany(
                    "INSERT INTO metrics(step, key, value_real, value_text) VALUES(?, ?, ?, ?) "
                    "ON CONFLICT(step, key) DO UPDATE SET "
                    "value_real=excluded.value_real, value_text=excluded.value_text;",
                    self._pending_metrics,
                )

            con.execute("COMMIT;")
        except BaseException:
            # without this the next BEGIN fails with
            # "cannot start a transaction within a transaction"
            self._rollback_quietly(con)
            raise

        self._pending_steps.clear()
        self._pending_metrics.clear()
        self._pending_key_minmax.clear()
        self._last_flush = time.time()
class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
            self,
            lora_name,
            org_module: torch.nn.Module,
            multiplier=1.0,
            lora_dim=4,
            alpha=1,
            dropout=None,
            rank_dropout=None,
            module_dropout=None,
            network: 'LoRASpecialNetwork' = None,
            use_bias: bool = False,
            is_ara: bool = False,
            **kwargs
    ):
        """Build the down/up projection pair for ``org_module``.

        Args:
            lora_name: name stored on the module as ``lora_name``.
            org_module: wrapped layer. Conv layers (class name in ``CONV_MODULES``) get
                Conv2d projections; everything else gets Linear projections.
            multiplier: stored as ``self.multiplier``.
            lora_dim: rank of the projection pair.
            alpha: scaling numerator; ``None`` or ``0`` means "use the rank" (no
                scaling). May be a tensor.
            dropout, rank_dropout, module_dropout: stored as attributes.
            network: owning network; its ``network_type`` decides full-rank mode
                (``"fullrank"``), in which ``lora_up`` is an identity.
            use_bias: add a bias to ``lora_up``; forced off when ``org_module`` has no bias.
            is_ara: when True a weak reference to this module is stored on
                ``org_module`` as ``ara_lora_ref``.
            **kwargs: accepted and ignored here.
        """
        self.can_merge_in = True
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        self.lora_name = lora_name
        self.orig_module_ref = weakref.ref(org_module)
        self.scalar = torch.tensor(1.0, device=org_module.weight.device)

        # if is ara lora module, mark it on the layer so memory manager can handle it
        if is_ara:
            org_module.ara_lora_ref = weakref.ref(self)
        # check if parent has bias. if not force use_bias to False
        if org_module.bias is None:
            use_bias = False

        is_conv = org_module.__class__.__name__ in CONV_MODULES
        if is_conv:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim
        self.full_rank = network.network_type.lower() == "fullrank"

        if is_conv:
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            if self.full_rank:
                self.lora_down = torch.nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False)
                self.lora_up = IdentityModule()
            else:
                self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
                self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
        else:
            if self.full_rank:
                self.lora_down = torch.nn.Linear(in_dim, out_dim, bias=False)
                self.lora_up = IdentityModule()
            else:
                self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
                self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)

        if isinstance(alpha, torch.Tensor):
            # without casting to float, bf16 causes an error; .cpu() is needed because
            # .numpy() raises for a tensor that does not live on the CPU
            alpha = alpha.detach().float().cpu().numpy()
        # alpha == 0 or None means alpha is the rank (no scaling)
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        if not self.full_rank:
            torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier: Union[float, List[float]] = multiplier
        # wrap the original module so it doesn't get weights updated
        self.org_module = [org_module]
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.is_checkpointing = False

    def apply_to(self):
        """Swap the wrapped module's ``forward`` for this module's, keeping the old one."""
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward
        # del self.org_module
+ """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + network: 'LoRASpecialNetwork' = None, + **kwargs + ): + self.can_merge_in = True + ToolkitModuleMixin.__init__(self, network=network) + torch.nn.Module.__init__(self) + self.lora_name = lora_name + # keep the original module out of our state_dict (list hides it from nn.Module registration) + self.org_module = [org_module] + self.orig_module_ref = weakref.ref(org_module) + self.multiplier: Union[float, List[float]] = multiplier + # these are unused for full modules but the mixin/forward path expects them to exist + self.dropout = None + self.rank_dropout = None + self.module_dropout = None + self.is_checkpointing = False + + # trainable delta, zero initialized so an untrained layer is a no-op (zero diff) + # dequantize first so the delta is full precision and shaped like the real (unpacked) weight + self.weight_is_quantized = _is_quantized_tensor(org_module.weight) + ref_weight = _dequantize_if_needed(org_module.weight) + self.diff = torch.nn.Parameter(torch.zeros_like(ref_weight)) + # some modules (e.g. 
Embedding) have no bias attribute at all + org_bias = getattr(org_module, 'bias', None) + if org_bias is not None: + self.diff_b = torch.nn.Parameter(torch.zeros_like(org_bias)) + else: + self.diff_b = None + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def forward(self, x, *args, **kwargs): + network: 'LoRASpecialNetwork' = self.network_ref() + skip = (not network.is_active) or network.is_merged_in or network._multiplier == 0 or network.is_lorm + if skip: + return self.org_forward(x, *args, **kwargs) + + om = self.org_module[0] + multiplier = network.torch_multiplier + # weight space application can't be done per sample, so use the mean (same as the DoRA path) + mult = multiplier.mean() if isinstance(multiplier, torch.Tensor) else multiplier + + orig_weight = om._parameters['weight'] + # dequantize quantized weights to full precision so the delta can be added (the original + # quantized tensor is restored in the finally block below) + base_weight = _dequantize_if_needed(orig_weight) + eff_weight = base_weight + (self.diff.to(base_weight.device) * mult).to(base_weight.dtype) + + has_bias = self.diff_b is not None and om._parameters.get('bias', None) is not None + if has_bias: + orig_bias = om._parameters['bias'] + eff_bias = orig_bias + (self.diff_b.to(orig_bias.device) * mult).to(orig_bias.dtype) + + # temporarily swap in the effective weights so the original forward (norm/linear/etc) uses them. + # this keeps autograd flowing into our delta while supporting any layer type. 
+ om._parameters['weight'] = eff_weight + if has_bias: + om._parameters['bias'] = eff_bias + try: + out = self.org_forward(x, *args, **kwargs) + finally: + om._parameters['weight'] = orig_weight + if has_bias: + om._parameters['bias'] = orig_bias + return out + + @torch.no_grad() + def merge_in(self: 'FullModule', merge_weight=1.0): + if not self.can_merge_in: + return + om = self.org_module[0] + if 'weight._data' in om.state_dict(): + # quanto quantized weight, can't merge + return + org_weight = om.weight + orig_dtype = org_weight.dtype + # dequantize torchao weights so we can fold the full precision delta in + merged_weight = _dequantize_if_needed(org_weight).float() + merge_weight * self.diff.float().to(org_weight.device) + if self.weight_is_quantized: + # re-quantize so the model stays quantized across continuous merge/reset cycles + from toolkit.util.quantize import get_torchao_config, requantize_module_weight + requantize_module_weight(om, merged_weight, orig_dtype, get_torchao_config(self._get_base_qtype())) + else: + om.weight.data = merged_weight.to(org_weight.device, orig_dtype) + # bias is never quantized + if self.diff_b is not None and getattr(om, 'bias', None) is not None: + om.bias.data = (om.bias.data.float() + merge_weight * self.diff_b.float().to(om.bias.device)).to(om.bias.dtype) + + def reset_weights(self: 'FullModule'): + with torch.no_grad(): + self.diff.zero_() + if self.diff_b is not None: + self.diff_b.zero_() + + +class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): + NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + + # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "ResnetBlock2D"] + UNET_TARGET_REPLACE_MODULE = ["UNet2DConditionModel"] + # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + 
LORA_PREFIX_UNET = "lora_unet" + PEFT_PREFIX_UNET = "unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" + LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + block_dims: Optional[List[int]] = None, + block_alphas: Optional[List[float]] = None, + conv_block_dims: Optional[List[int]] = None, + conv_block_alphas: Optional[List[float]] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, + train_text_encoder: Optional[bool] = True, + use_text_encoder_1: bool = True, + use_text_encoder_2: bool = True, + train_unet: Optional[bool] = True, + is_sdxl=False, + is_v2=False, + is_v3=False, + is_pixart: bool = False, + is_auraflow: bool = False, + is_flux: bool = False, + is_lumina2: bool = False, + use_bias: bool = False, + is_lorm: bool = False, + ignore_if_contains = None, + only_if_contains = None, + full_if_contains = None, + parameter_threshold: float = 0.0, + attn_only: bool = False, + target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE, + target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3, + network_type: str = "lora", + full_train_in_out: bool = False, + transformer_only: bool = False, + peft_format: bool = False, + is_assistant_adapter: bool = False, + is_transformer: bool = False, + base_model: 'StableDiffusion' = None, + is_ara: bool = False, + **kwargs + ) -> None: + """ + LoRA network: すごく引数が多いが、パターンは以下の通り + 1. lora_dimとalphaを指定 + 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 + 3. 
block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない + 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する + 5. modules_dimとmodules_alphaを指定 (推論用) + """ + # call the parent of the parent we are replacing (LoRANetwork) init + torch.nn.Module.__init__(self) + ToolkitNetworkMixin.__init__( + self, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + is_sdxl=is_sdxl, + is_v2=is_v2, + is_lorm=is_lorm, + **kwargs + ) + if ignore_if_contains is None: + ignore_if_contains = [] + self.ignore_if_contains = ignore_if_contains + # full_if_contains: any layer (even linear/conv) whose name matches becomes a full weight + # module instead of a normal lora module + if full_if_contains is None: + full_if_contains = [] + elif isinstance(full_if_contains, str): + full_if_contains = [full_if_contains] + self.full_if_contains = full_if_contains + self.transformer_only = transformer_only + self.base_model_ref = None + if base_model is not None: + self.base_model_ref = weakref.ref(base_model) + + self.only_if_contains: Union[List, None] = only_if_contains + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.is_checkpointing = False + self._multiplier: float = 1.0 + self.is_active: bool = False + self.torch_multiplier = None + # triggers the state updates + self.multiplier = multiplier + self.is_sdxl = is_sdxl + self.is_v2 = is_v2 + self.is_v3 = is_v3 + self.is_pixart = is_pixart + self.is_auraflow = is_auraflow + self.is_flux = is_flux + self.is_lumina2 = is_lumina2 + self.network_type = network_type + self.is_assistant_adapter = is_assistant_adapter + self.full_rank = network_type.lower() == "fullrank" + self.is_ara = is_ara + if self.network_type.lower() == "dora": + self.module_class = DoRAModule + module_class = DoRAModule + elif self.network_type.lower() == "lokr": + 
self.module_class = LokrModule + module_class = LokrModule + self.network_config: NetworkConfig = kwargs.get("network_config", None) + + self.peft_format = peft_format + self.is_transformer = is_transformer + + # use the old format for older models unless the user has specified otherwise + self.use_old_lokr_format = False + if self.network_config is not None and hasattr(self.network_config, 'old_lokr_format'): + self.use_old_lokr_format = self.network_config.old_lokr_format + # also allow a false from the model itself + if base_model is not None and not base_model.use_old_lokr_format: + self.use_old_lokr_format = False + + # always do peft for flux only for now + if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer: + # don't do peft format for lokr if using old format + if self.network_type.lower() != "lokr" or not self.use_old_lokr_format: + self.peft_format = True + + if self.peft_format: + # no alpha for peft + self.alpha = self.lora_dim + alpha = self.alpha + self.conv_alpha = self.conv_lora_dim + conv_alpha = self.conv_alpha + + self.full_train_in_out = full_train_in_out + + if modules_dim is not None: + print(f"create LoRA network from weights") + elif block_dims is not None: + print(f"create LoRA network from block_dims") + print( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + print(f"block_dims: {block_dims}") + print(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + print(f"conv_block_dims: {conv_block_dims}") + print(f"conv_block_alphas: {conv_block_alphas}") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + print( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + if self.conv_lora_dim is not None: + print( + f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + + # create module instances + def create_modules( + is_unet: bool, + text_encoder_idx: Optional[int], # None, 1, 2 + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + unet_prefix = self.LORA_PREFIX_UNET + if self.peft_format: + unet_prefix = self.PEFT_PREFIX_UNET + if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer: + unet_prefix = f"lora_transformer" + if self.peft_format: + unet_prefix = "transformer" + + prefix = ( + unet_prefix + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) + ) + loras = [] + skipped = [] + attached_modules = [] + lora_shape_dict = {} + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ in LINEAR_MODULES + is_conv2d = child_module.__class__.__name__ in CONV_MODULES + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + lora_name = [prefix, name, child_name] + # filter out blank + lora_name = [x for x in lora_name if x and x != ""] + lora_name = ".".join(lora_name) + # if it doesnt have a name, it wil have two dots + lora_name.replace("..", ".") + clean_name = lora_name + if self.peft_format: + # we replace this on saving + lora_name = lora_name.replace(".", "$$") + else: + lora_name = lora_name.replace(".", "_") + + # decide if this should be a full weight module instead of a normal lora. 
+ # - all_layers: every remaining weight bearing leaf that isn't linear/conv + # (norm layers, embeddings, stray biases, etc) + # - full_if_contains: any matching layer, INCLUDING linear/conv, overriding the + # normal lora for it + all_layers = self.network_config is not None and getattr(self.network_config, 'all_layers', False) + is_leaf_with_weight = ( + len(list(child_module.children())) == 0 + and isinstance(getattr(child_module, 'weight', None), torch.nn.Parameter) + ) + matches_full_if_contains = len(self.full_if_contains) > 0 and ( + any([word in clean_name for word in self.full_if_contains]) + or any([word in lora_name for word in self.full_if_contains]) + ) + is_full_layer = is_leaf_with_weight and ( + matches_full_if_contains + or (all_layers and not is_linear and not is_conv2d) + ) + + skip = False + if any([word in clean_name for word in self.ignore_if_contains]): + skip = True + + # see if it is over threshold + if count_parameters(child_module) < parameter_threshold: + skip = True + + if self.transformer_only and is_unet: + transformer_block_names = None + if base_model is not None: + transformer_block_names = base_model.get_transformer_block_names() + + if transformer_block_names is not None: + # match against clean_name (dotted) so block names can be + # dotted paths (e.g. "model.language_model.layers"); lora_name + # has dots replaced with "$$"/"_" and wouldn't match. 
+ if not any([block_name in clean_name for block_name in transformer_block_names]): + skip = True + else: + if self.is_pixart: + if "transformer_blocks" not in lora_name: + skip = True + if self.is_flux: + if "transformer_blocks" not in lora_name: + skip = True + if self.is_lumina2: + if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name: + skip = True + if self.is_v3: + if "transformer_blocks" not in lora_name: + skip = True + + # handle custom models + if hasattr(root_module, 'transformer_blocks'): + if "transformer_blocks" not in lora_name: + skip = True + + if hasattr(root_module, 'blocks'): + if "blocks" not in lora_name: + skip = True + + if hasattr(root_module, 'single_blocks'): + if "single_blocks" not in lora_name and "double_blocks" not in lora_name: + skip = True + + if (is_linear or is_conv2d) and not skip and not is_full_layer: + + if self.only_if_contains is not None: + if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]): + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or ( + self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue + + module_kwargs = {} + + if self.network_type.lower() == "lokr": + module_kwargs["factor"] = self.network_config.lokr_factor + + if self.is_ara: + module_kwargs["is_ara"] = True + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + 
module_dropout=module_dropout, + network=self, + parent=module, + use_bias=use_bias, + **module_kwargs + ) + loras.append(lora) + if self.network_type.lower() == "lokr": + try: + lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)] + except: + pass + else: + if self.full_rank: + lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)] + else: + lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)] + + elif is_full_layer and not skip: + if self.only_if_contains is not None: + if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]): + continue + + lora = FullModule( + lora_name, + child_module, + self.multiplier, + network=self, + parent=module, + ) + loras.append(lora) + lora_shape_dict[lora_name] = [list(lora.diff.shape)] + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + if train_text_encoder: + for i, text_encoder in enumerate(text_encoders): + if not use_text_encoder_1 and i == 0: + continue + if not use_text_encoder_2 and i == 1: + continue + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}:") + else: + index = None + print(f"create LoRA for Text Encoder:") + + replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + + if self.is_pixart: + replace_modules = ["T5EncoderModel"] + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = target_lin_modules + if modules_dim is not None or 
self.conv_lora_dim is not None or conv_block_dims is not None: + target_modules += target_conv_modules + + if is_v3: + target_modules = ["SD3Transformer2DModel"] + + if is_pixart: + target_modules = ["PixArtTransformer2DModel"] + + if is_auraflow: + target_modules = ["AuraFlowTransformer2DModel"] + + if is_flux: + target_modules = ["FluxTransformer2DModel"] + + if is_lumina2: + target_modules = ["Lumina2Transformer2DModel"] + + if train_unet: + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) + else: + self.unet_loras = [] + skipped_un = [] + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + print( + f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + print(f"\t{name}") + + self.up_lr_weight: List[float] = None + self.down_lr_weight: List[float] = None + self.mid_lr_weight: float = None + self.block_lr = False + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + if self.full_train_in_out: + print("full train in out") + # we are going to retrain the main in out layers for VAE change usually + if self.is_pixart: + transformer: PixArtTransformer2DModel = unet + self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed) + self.transformer_proj_out = copy.deepcopy(transformer.proj_out) + + transformer.pos_embed = self.transformer_pos_embed + transformer.proj_out = self.transformer_proj_out + + elif self.is_auraflow: + transformer: AuraFlowTransformer2DModel = unet + self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed) + self.transformer_proj_out = copy.deepcopy(transformer.proj_out) + + transformer.pos_embed = self.transformer_pos_embed + 
transformer.proj_out = self.transformer_proj_out + + elif base_model is not None and base_model.arch == "wan21": + transformer: WanTransformer3DModel = unet + self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding) + self.transformer_proj_out = copy.deepcopy(transformer.proj_out) + + transformer.patch_embedding = self.transformer_pos_embed + transformer.proj_out = self.transformer_proj_out + + else: + unet: UNet2DConditionModel = unet + unet_conv_in: torch.nn.Conv2d = unet.conv_in + unet_conv_out: torch.nn.Conv2d = unet.conv_out + + # clone these and replace their forwards with ours + self.unet_conv_in = copy.deepcopy(unet_conv_in) + self.unet_conv_out = copy.deepcopy(unet_conv_out) + unet.conv_in = self.unet_conv_in + unet.conv_out = self.unet_conv_out + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # call Lora prepare_optimizer_params + all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr) + + if self.full_train_in_out: + base_model = self.base_model_ref() if self.base_model_ref is not None else None + if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"): + all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())}) + all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())}) + else: + all_params.append({"lr": unet_lr, "params": list(self.unet_conv_in.parameters())}) + all_params.append({"lr": unet_lr, "params": list(self.unet_conv_out.parameters())}) + + return all_params + + diff --git a/ai-toolkit/toolkit/lorm.py b/ai-toolkit/toolkit/lorm.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfdb516be12d98e464a7d9a96bc4e19b83b9f91 --- /dev/null +++ b/ai-toolkit/toolkit/lorm.py @@ -0,0 +1,461 @@ +from typing import Union, Tuple, Literal, Optional + +import torch +import torch.nn as nn +from diffusers import UNet2DConditionModel +from torch 
import Tensor +from tqdm import tqdm + +from toolkit.config_modules import LoRMConfig + +conv = nn.Conv2d +lin = nn.Linear +_size_2_t = Union[int, Tuple[int, int]] + +ExtractMode = Union[ + 'fixed', + 'threshold', + 'ratio', + 'quantile', + 'percentage' +] + +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear' +] +CONV_MODULES = [ + # 'Conv2d', + # 'LoRACompatibleConv' +] + +UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + # "ResnetBlock2D", + "Downsample2D", + "Upsample2D", +] + +LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE + +UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", +] + +UNET_MODULES_TO_AVOID = [ +] + + +# Low Rank Convolution +class LoRMCon2d(nn.Module): + def __init__( + self, + in_channels: int, + lorm_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 'same', + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + super().__init__() + self.in_channels = in_channels + self.lorm_channels = lorm_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.padding_mode = padding_mode + + self.down = nn.Conv2d( + in_channels=in_channels, + out_channels=lorm_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + padding_mode=padding_mode, + device=device, + dtype=dtype + ) + + # Kernel size on the up is always 1x1. 
+ # I don't think you could calculate a dual 3x3, or I can't at least + + self.up = nn.Conv2d( + in_channels=lorm_channels, + out_channels=out_channels, + kernel_size=(1, 1), + stride=1, + padding='same', + dilation=1, + groups=1, + bias=bias, + padding_mode='zeros', + device=device, + dtype=dtype + ) + + def forward(self, input: Tensor, *args, **kwargs) -> Tensor: + x = input + x = self.down(x) + x = self.up(x) + return x + + +class LoRMLinear(nn.Module): + def __init__( + self, + in_features: int, + lorm_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None + ) -> None: + super().__init__() + self.in_features = in_features + self.lorm_features = lorm_features + self.out_features = out_features + + self.down = nn.Linear( + in_features=in_features, + out_features=lorm_features, + bias=False, + device=device, + dtype=dtype + + ) + self.up = nn.Linear( + in_features=lorm_features, + out_features=out_features, + bias=bias, + # bias=True, + device=device, + dtype=dtype + ) + + def forward(self, input: Tensor, *args, **kwargs) -> Tensor: + x = input + x = self.down(x) + x = self.up(x) + return x + + +def extract_conv( + weight: Union[torch.Tensor, nn.Parameter], + mode='fixed', + mode_param=0, + device='cpu' +) -> Tuple[Tensor, Tensor, int, Tensor]: + weight = weight.to(device) + out_ch, in_ch, kernel_size, _ = weight.shape + + U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1)) + if mode == 'percentage': + assert 0 <= mode_param <= 1 # Ensure it's a valid percentage. 
+ original_params = out_ch * in_ch * kernel_size * kernel_size + desired_params = mode_param * original_params + # Solve for lora_rank from the equation + lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch)) + elif mode == 'fixed': + lora_rank = mode_param + elif mode == 'threshold': + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param).item() + elif mode == 'ratio': + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s).item() + elif mode == 'quantile' or mode == 'percentile': + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum).item() + else: + raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2: + lora_rank = int(out_ch / 2) + print(f"rank is higher than it should be") + # print(f"Skipping layer as determined rank is too high") + # return None, None, None, None + # return weight, 'full' + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach() + extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach() + del U, S, Vh, weight + return extract_weight_A, extract_weight_B, lora_rank, diff + + +def extract_linear( + weight: Union[torch.Tensor, nn.Parameter], + mode='fixed', + mode_param=0, + device='cpu', +) -> Tuple[Tensor, Tensor, int, Tensor]: + weight = weight.to(device) + out_ch, in_ch = weight.shape + + U, S, Vh = torch.linalg.svd(weight) + + if mode == 'percentage': + assert 0 <= mode_param <= 1 # Ensure it's a valid percentage. 
+ desired_params = mode_param * out_ch * in_ch + # Solve for lora_rank from the equation + lora_rank = int(desired_params / (in_ch + out_ch)) + elif mode == 'fixed': + lora_rank = mode_param + elif mode == 'threshold': + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param).item() + elif mode == 'ratio': + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s).item() + elif mode == 'quantile': + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum).item() + else: + raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2: + # print(f"rank is higher than it should be") + lora_rank = int(out_ch / 2) + # return weight, 'full' + # print(f"Skipping layer as determined rank is too high") + # return None, None, None, None + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - U @ Vh).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch).detach() + extract_weight_B = U.reshape(out_ch, lora_rank).detach() + del U, S, Vh, weight + return extract_weight_A, extract_weight_B, lora_rank, diff + + +def replace_module_by_path(network, name, module): + """Replace a module in a network by its name.""" + name_parts = name.split('.') + current_module = network + for part in name_parts[:-1]: + current_module = getattr(current_module, part) + try: + setattr(current_module, name_parts[-1], module) + except Exception as e: + print(e) + + +def count_parameters(module): + return sum(p.numel() for p in module.parameters()) + + +def compute_optimal_bias(original_module, linear_down, linear_up, X): + Y_original = original_module(X) + Y_approx = linear_up(linear_down(X)) + E = Y_original - Y_approx + + optimal_bias = E.mean(dim=0) + + return 
optimal_bias + + +def format_with_commas(n): + return f"{n:,}" + + +def print_lorm_extract_details( + start_num_params: int, + end_num_params: int, + num_replaced: int, +): + start_formatted = format_with_commas(start_num_params) + end_formatted = format_with_commas(end_num_params) + num_replaced_formatted = format_with_commas(num_replaced) + + width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted)) + + print(f"Convert UNet result:") + print(f" - converted: {num_replaced:>{width},} modules") + print(f" - start: {start_num_params:>{width},} params") + print(f" - end: {end_num_params:>{width},} params") + + +lorm_ignore_if_contains = [ + 'proj_out', 'proj_in', +] + +lorm_parameter_threshold = 1000000 + + +@torch.no_grad() +def convert_diffusers_unet_to_lorm( + unet: UNet2DConditionModel, + config: LoRMConfig, +): + print('Converting UNet to LoRM UNet') + start_num_params = count_parameters(unet) + named_modules = list(unet.named_modules()) + + num_replaced = 0 + + pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet") + layer_names_replaced = [] + converted_modules = [] + ignore_if_contains = [ + 'proj_out', 'proj_in', + ] + + for name, module in named_modules: + module_name = module.__class__.__name__ + if module_name in UNET_TARGET_REPLACE_MODULE: + for child_name, child_module in module.named_modules(): + new_module: Union[LoRMCon2d, LoRMLinear, None] = None + # if child name includes attn, skip it + combined_name = combined_name = f"{name}.{child_name}" + # if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None: + # pass + + lorm_config = config.get_config_for_module(combined_name) + + extract_mode = lorm_config.extract_mode + extract_mode_param = lorm_config.extract_mode_param + parameter_threshold = lorm_config.parameter_threshold + + if any([word in child_name for word in ignore_if_contains]): + pass + + elif child_module.__class__.__name__ in LINEAR_MODULES: + if count_parameters(child_module) > 
parameter_threshold: + + # dtype = child_module.weight.dtype + dtype = torch.float32 + # extract and convert + down_weight, up_weight, lora_dim, diff = extract_linear( + weight=child_module.weight.clone().detach().float(), + mode=extract_mode, + mode_param=extract_mode_param, + device=child_module.weight.device, + ) + if down_weight is None: + continue + down_weight = down_weight.to(dtype=dtype) + up_weight = up_weight.to(dtype=dtype) + bias_weight = None + if child_module.bias is not None: + bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype) + # linear layer weights = (out_features, in_features) + new_module = LoRMLinear( + in_features=down_weight.shape[1], + lorm_features=lora_dim, + out_features=up_weight.shape[0], + bias=bias_weight is not None, + device=down_weight.device, + dtype=down_weight.dtype + ) + + # replace the weights + new_module.down.weight.data = down_weight + new_module.up.weight.data = up_weight + if bias_weight is not None: + new_module.up.bias.data = bias_weight + # else: + # new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data) + + # bias_correction = compute_optimal_bias( + # child_module, + # new_module.down, + # new_module.up, + # torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype) + # ) + # new_module.up.bias.data += bias_correction + + elif child_module.__class__.__name__ in CONV_MODULES: + if count_parameters(child_module) > parameter_threshold: + dtype = child_module.weight.dtype + down_weight, up_weight, lora_dim, diff = extract_conv( + weight=child_module.weight.clone().detach().float(), + mode=extract_mode, + mode_param=extract_mode_param, + device=child_module.weight.device, + ) + if down_weight is None: + continue + down_weight = down_weight.to(dtype=dtype) + up_weight = up_weight.to(dtype=dtype) + bias_weight = None + if child_module.bias is not None: + bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype) + + new_module = LoRMCon2d( + 
in_channels=down_weight.shape[1], + lorm_channels=lora_dim, + out_channels=up_weight.shape[0], + kernel_size=child_module.kernel_size, + dilation=child_module.dilation, + padding=child_module.padding, + padding_mode=child_module.padding_mode, + stride=child_module.stride, + bias=bias_weight is not None, + device=down_weight.device, + dtype=down_weight.dtype + ) + # replace the weights + new_module.down.weight.data = down_weight + new_module.up.weight.data = up_weight + if bias_weight is not None: + new_module.up.bias.data = bias_weight + + if new_module: + combined_name = f"{name}.{child_name}" + replace_module_by_path(unet, combined_name, new_module) + converted_modules.append(new_module) + num_replaced += 1 + layer_names_replaced.append( + f"{combined_name} - {format_with_commas(count_parameters(child_module))}") + + pbar.update(1) + pbar.close() + end_num_params = count_parameters(unet) + + def sorting_key(s): + # Extract the number part, remove commas, and convert to integer + return int(s.split("-")[1].strip().replace(",", "")) + + sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True) + for layer_name in sorted_layer_names_replaced: + print(layer_name) + + print_lorm_extract_details( + start_num_params=start_num_params, + end_num_params=end_num_params, + num_replaced=num_replaced, + ) + + return converted_modules diff --git a/ai-toolkit/toolkit/losses.py b/ai-toolkit/toolkit/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..fef9310dfe58c53aae7930238af20eddb4ce0a65 --- /dev/null +++ b/ai-toolkit/toolkit/losses.py @@ -0,0 +1,113 @@ +import torch +from .llvae import LosslessLatentEncoder + + +def total_variation(image): + """ + Compute normalized total variation. 
+ Inputs: + - image: PyTorch Variable of shape (N, C, H, W) + Returns: + - TV: total variation normalized by the number of elements + """ + n_elements = image.shape[1] * image.shape[2] * image.shape[3] + return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + + torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements) + +def total_variation_deltas(image): + """ + Compute per-pixel total variation deltas. + Input: + - image: Tensor of shape (N, C, H, W) + Returns: + - Tensor with shape (N, C, H, W), padded to match input shape + """ + dh = torch.zeros_like(image) + dv = torch.zeros_like(image) + + dh[:, :, :, :-1] = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1]) + dv[:, :, :-1, :] = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :]) + + return dh + dv + + +class ComparativeTotalVariation(torch.nn.Module): + """ + Compute the comparative loss in tv between two images. to match their tv + """ + + def forward(self, pred, target): + return torch.abs(total_variation(pred) - total_variation(target)) + + +# Gradient penalty +def get_gradient_penalty(critic, real, fake, device): + with torch.autocast(device_type='cuda'): + real = real.float() + fake = fake.float() + alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float() + interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True) + if torch.isnan(interpolates).any(): + print('d_interpolates is nan') + d_interpolates = critic(interpolates) + fake = torch.ones(real.size(0), 1, device=device) + + if torch.isnan(d_interpolates).any(): + print('fake is nan') + gradients = torch.autograd.grad( + outputs=d_interpolates, + inputs=interpolates, + grad_outputs=fake, + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + # see if any are nan + if torch.isnan(gradients).any(): + print('gradients is nan') + + gradients = gradients.view(gradients.size(0), -1) + gradient_norm = gradients.norm(2, dim=1) + gradient_penalty = ((gradient_norm - 1) ** 2).mean() 
+ return gradient_penalty.float() + + +class PatternLoss(torch.nn.Module): + def __init__(self, pattern_size=4, dtype=torch.float32): + super().__init__() + self.pattern_size = pattern_size + self.llvae_encoder = LosslessLatentEncoder(3, pattern_size, dtype=dtype) + + def forward(self, pred, target): + pred_latents = self.llvae_encoder(pred) + target_latents = self.llvae_encoder(target) + + matrix_pixels = self.pattern_size * self.pattern_size + + color_chans = pred_latents.shape[1] // 3 + # pytorch + r_chans, g_chans, b_chans = torch.split(pred_latents, [color_chans, color_chans, color_chans], 1) + r_chans_target, g_chans_target, b_chans_target = torch.split(target_latents, [color_chans, color_chans, color_chans], 1) + + def separated_chan_loss(latent_chan): + nonlocal matrix_pixels + chan_mean = torch.mean(latent_chan, dim=[1, 2, 3]) + chan_splits = torch.split(latent_chan, [1 for i in range(matrix_pixels)], 1) + chan_loss = None + for chan in chan_splits: + this_mean = torch.mean(chan, dim=[1, 2, 3]) + this_chan_loss = torch.abs(this_mean - chan_mean) + if chan_loss is None: + chan_loss = this_chan_loss + else: + chan_loss = chan_loss + this_chan_loss + chan_loss = chan_loss * (1 / matrix_pixels) + return chan_loss + + r_chan_loss = torch.abs(separated_chan_loss(r_chans) - separated_chan_loss(r_chans_target)) + g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target)) + b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target)) + return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333 + + diff --git a/ai-toolkit/toolkit/lycoris_special.py b/ai-toolkit/toolkit/lycoris_special.py new file mode 100644 index 0000000000000000000000000000000000000000..84021b49cdd3853972721924c1f957203e17e49d --- /dev/null +++ b/ai-toolkit/toolkit/lycoris_special.py @@ -0,0 +1,373 @@ +import math +import os +from typing import Optional, Union, List, Type + +import torch +from lycoris.kohya import LycorisNetwork, 
LoConModule +from lycoris.modules.glora import GLoRAModule +from torch import nn +from transformers import CLIPTextModel +from torch.nn import functional as F +from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin + +# diffusers specific stuff +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear' +] +CONV_MODULES = [ + 'Conv2d', + 'LoRACompatibleConv' +] + +class LoConSpecialModule(ToolkitModuleMixin, LoConModule, ExtractableModuleMixin): + def __init__( + self, + lora_name, org_module: nn.Module, + multiplier=1.0, + lora_dim=4, alpha=1, + dropout=0., rank_dropout=0., module_dropout=0., + use_cp=False, + network: 'LycorisSpecialNetwork' = None, + use_bias=False, + **kwargs, + ): + """ if alpha == 0 or None, alpha is rank (no scaling). """ + # call super of super + ToolkitModuleMixin.__init__(self, network=network) + torch.nn.Module.__init__(self) + self.lora_name = lora_name + self.lora_dim = lora_dim + self.cp = False + + # check if parent has bias. 
if not force use_bias to False + if org_module.bias is None: + use_bias = False + + self.scalar = nn.Parameter(torch.tensor(0.0)) + orig_module_name = org_module.__class__.__name__ + if orig_module_name in CONV_MODULES: + self.isconv = True + # For general LoCon + in_dim = org_module.in_channels + k_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + out_dim = org_module.out_channels + self.down_op = F.conv2d + self.up_op = F.conv2d + if use_cp and k_size != (1, 1): + self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False) + self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False) + self.cp = True + else: + self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False) + self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=use_bias) + elif orig_module_name in LINEAR_MODULES: + self.isconv = False + self.down_op = F.linear + self.up_op = F.linear + if orig_module_name == 'GroupNorm': + # RuntimeError: mat1 and mat2 shapes cannot be multiplied (56320x120 and 320x32) + in_dim = org_module.num_channels + out_dim = org_module.num_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = nn.Linear(lora_dim, out_dim, bias=use_bias) + else: + raise NotImplementedError + self.shape = org_module.weight.shape + + if dropout: + self.dropout = nn.Dropout(dropout) + else: + self.dropout = nn.Identity() + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + 
torch.nn.init.kaiming_uniform_(self.lora_up.weight) + if self.cp: + torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5)) + + self.multiplier = multiplier + self.org_module = [org_module] + self.register_load_state_dict_post_hook(self.load_weight_hook) + + def load_weight_hook(self, *args, **kwargs): + self.scalar = nn.Parameter(torch.ones_like(self.scalar)) + + +class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D", + # 'UNet2DConditionModel', + # 'Conv2d', + # 'Timesteps', + # 'TimestepEmbedding', + # 'Linear', + # 'SiLU', + # 'ModuleList', + # 'DownBlock2D', + # 'ResnetBlock2D', # need + # 'GroupNorm', + # 'LoRACompatibleConv', + # 'LoRACompatibleLinear', + # 'Dropout', + # 'CrossAttnDownBlock2D', # needed + # 'Transformer2DModel', # maybe not, has duplicates + # 'BasicTransformerBlock', # duplicates + # 'LayerNorm', + # 'Attention', + # 'FeedForward', + # 'GEGLU', + # 'UpBlock2D', + # 'UNetMidBlock2DCrossAttn' + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + use_cp: Optional[bool] = False, + network_module: Type[object] = LoConSpecialModule, + train_unet: bool = True, + train_text_encoder: bool = True, + use_text_encoder_1: bool = True, + use_text_encoder_2: bool = True, + use_bias: bool = False, + is_lorm: bool = False, + **kwargs, + ) -> None: + # call ToolkitNetworkMixin super + ToolkitNetworkMixin.__init__( + self, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + is_lorm=is_lorm, + 
**kwargs + ) + # call the parent of the parent LycorisNetwork + torch.nn.Module.__init__(self) + + # LyCORIS unique stuff + if dropout is None: + dropout = 0 + if rank_dropout is None: + rank_dropout = 0 + if module_dropout is None: + module_dropout = 0 + self.train_unet = train_unet + self.train_text_encoder = train_text_encoder + + self.torch_multiplier = None + # triggers a tensor update + self.multiplier = multiplier + self.lora_dim = lora_dim + + if not self.ENABLE_CONV or conv_lora_dim is None: + conv_lora_dim = 0 + conv_alpha = 0 + + self.conv_lora_dim = int(conv_lora_dim) + if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim: + print('Apply different lora dim for conv layer') + print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}') + elif self.conv_lora_dim == 0: + print('Disable conv layer') + + self.alpha = alpha + self.conv_alpha = float(conv_alpha) + if self.conv_lora_dim and self.alpha != self.conv_alpha: + print('Apply different alpha value for conv layer') + print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}') + + if 1 >= dropout >= 0: + print(f'Use Dropout value: {dropout}') + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + # create module instances + def create_modules( + prefix, + root_module: torch.nn.Module, + target_replace_modules, + target_replace_names=[] + ) -> List[network_module]: + print('Create LyCORIS Module') + loras = [] + # remove this + named_modules = root_module.named_modules() + # add a few to tthe generator + + for name, module in named_modules: + module_name = module.__class__.__name__ + if module_name in target_replace_modules: + if module_name in self.MODULE_ALGO_MAP: + algo = self.MODULE_ALGO_MAP[module_name] + else: + algo = network_module + for child_name, child_module in module.named_modules(): + lora_name = prefix + '.' + name + '.' 
+ child_name + lora_name = lora_name.replace('.', '_') + if lora_name.startswith('lora_unet_input_blocks_1_0_emb_layers_1'): + print(f"{lora_name}") + + if child_module.__class__.__name__ in LINEAR_MODULES and lora_dim > 0: + lora = algo( + lora_name, child_module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + elif child_module.__class__.__name__ in CONV_MODULES: + k_size, *_ = child_module.kernel_size + if k_size == 1 and lora_dim > 0: + lora = algo( + lora_name, child_module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + elif conv_lora_dim > 0: + lora = algo( + lora_name, child_module, self.multiplier, + self.conv_lora_dim, self.conv_alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + else: + continue + else: + continue + loras.append(lora) + elif name in target_replace_names: + if name in self.NAME_ALGO_MAP: + algo = self.NAME_ALGO_MAP[name] + else: + algo = network_module + lora_name = prefix + '.' 
+ name + lora_name = lora_name.replace('.', '_') + if module.__class__.__name__ == 'Linear' and lora_dim > 0: + lora = algo( + lora_name, module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + parent=module, + network=self, + use_bias=use_bias, + **kwargs + ) + elif module.__class__.__name__ == 'Conv2d': + k_size, *_ = module.kernel_size + if k_size == 1 and lora_dim > 0: + lora = algo( + lora_name, module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + elif conv_lora_dim > 0: + lora = algo( + lora_name, module, self.multiplier, + self.conv_lora_dim, self.conv_alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + network=self, + parent=module, + use_bias=use_bias, + **kwargs + ) + else: + continue + else: + continue + loras.append(lora) + return loras + + if network_module == GLoRAModule: + print('GLoRA enabled, only train transformer') + # only train transformer (for GLoRA) + LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + ] + LycorisSpecialNetwork.UNET_TARGET_REPLACE_NAME = [] + + if isinstance(text_encoder, list): + text_encoders = text_encoder + use_index = True + else: + text_encoders = [text_encoder] + use_index = False + + self.text_encoder_loras = [] + if self.train_text_encoder: + for i, te in enumerate(text_encoders): + if not use_text_encoder_1 and i == 0: + continue + if not use_text_encoder_2 and i == 1: + continue + self.text_encoder_loras.extend(create_modules( + LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''), + te, + LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + )) + print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.") + if self.train_unet: + self.unet_loras = 
create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet, + LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE) + else: + self.unet_loras = [] + print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.") + + self.weights_sd = None + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) diff --git a/ai-toolkit/toolkit/lycoris_utils.py b/ai-toolkit/toolkit/lycoris_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af11ee9ef52c9c0a42ac34afc3e3fa42c7c4b83d --- /dev/null +++ b/ai-toolkit/toolkit/lycoris_utils.py @@ -0,0 +1,536 @@ +# heavily based on https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/utils.py + +from typing import * + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torch.linalg as linalg + +from tqdm import tqdm +from collections import OrderedDict + + +def make_sparse(t: torch.Tensor, sparsity=0.95): + abs_t = torch.abs(t) + np_array = abs_t.detach().cpu().numpy() + quan = float(np.quantile(np_array, sparsity)) + sparse_t = t.masked_fill(abs_t < quan, 0) + return sparse_t + + +def extract_conv( + weight: Union[torch.Tensor, nn.Parameter], + mode='fixed', + mode_param=0, + device='cpu', + is_cp=False, +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch, kernel_size, _ = weight.shape + + U, S, Vh = linalg.svd(weight.reshape(out_ch, -1)) + + if mode == 'fixed': + lora_rank = mode_param + elif mode == 'threshold': + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param) + elif mode == 'ratio': + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s) + elif mode == 'quantile' or mode == 'percentile': + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < 
min_cum_sum) + else: + raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2 and not is_cp: + return weight, 'full' + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach() + extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), 'low rank' + + +def extract_linear( + weight: Union[torch.Tensor, nn.Parameter], + mode='fixed', + mode_param=0, + device='cpu', +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch = weight.shape + + U, S, Vh = linalg.svd(weight) + + if mode == 'fixed': + lora_rank = mode_param + elif mode == 'threshold': + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param) + elif mode == 'ratio': + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s) + elif mode == 'quantile' or mode == 'percentile': + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum) + else: + raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"') + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2: + return weight, 'full' + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + diff = (weight - U @ Vh).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch).detach() + extract_weight_B = U.reshape(out_ch, lora_rank).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), 'low rank' + + +def extract_diff( 
+ base_model, + db_model, + mode='fixed', + linear_mode_param=0, + conv_mode_param=0, + extract_device='cpu', + use_bias=False, + sparsity=0.98, + small_conv=True, + linear_only=False, + extract_unet=True, + extract_text_encoder=True, +): + meta = OrderedDict() + + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D" + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + if linear_only: + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + ] + + if not extract_unet: + UNET_TARGET_REPLACE_MODULE = [] + UNET_TARGET_REPLACE_NAME = [] + + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + + if not extract_text_encoder: + TEXT_ENCODER_TARGET_REPLACE_MODULE = [] + + LORA_PREFIX_UNET = 'lora_unet' + LORA_PREFIX_TEXT_ENCODER = 'lora_te' + + def make_state_dict( + prefix, + root_module: torch.nn.Module, + target_module: torch.nn.Module, + target_replace_modules, + target_replace_names=[] + ): + loras = {} + temp = {} + temp_name = {} + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + temp[name] = {} + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}: + continue + temp[name][child_name] = child_module.weight + elif name in target_replace_names: + temp_name[name] = module.weight + + for name, module in tqdm(list(target_module.named_modules())): + if name in temp: + weights = temp[name] + for child_name, child_module in module.named_modules(): + lora_name = prefix + '.' + name + '.' 
+ child_name + lora_name = lora_name.replace('.', '_') + layer = child_module.__class__.__name__ + if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}: + root_weight = child_module.weight + if torch.allclose(root_weight, weights[child_name]): + continue + + if layer == 'Linear' or layer == 'LoRACompatibleLinear': + weight, decompose_mode = extract_linear( + (child_module.weight - weights[child_name]), + mode, + linear_mode_param, + device=extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + elif layer == 'Conv2d' or layer == 'LoRACompatibleConv': + is_linear = (child_module.weight.shape[2] == 1 + and child_module.weight.shape[3] == 1) + if not is_linear and linear_only: + continue + weight, decompose_mode = extract_conv( + (child_module.weight - weights[child_name]), + mode, + linear_mode_param if is_linear else conv_mode_param, + device=extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + if small_conv and not is_linear and decompose_mode == 'low rank': + dim = extract_a.size(0) + (extract_c, extract_a, _), _ = extract_conv( + extract_a.transpose(0, 1), + 'fixed', dim, + extract_device, True + ) + extract_a = extract_a.transpose(0, 1) + extract_c = extract_c.transpose(0, 1) + loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half() + diff = child_module.weight - torch.einsum( + 'i j k l, j r, p i -> p r k l', + extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1) + ).detach().cpu().contiguous() + del extract_c + else: + continue + if decompose_mode == 'low rank': + loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half() + loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half() + loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half() + if use_bias: + diff = diff.detach().cpu().reshape(extract_b.size(0), -1) + sparse_diff = make_sparse(diff, 
sparsity).to_sparse().coalesce() + + indices = sparse_diff.indices().to(torch.int16) + values = sparse_diff.values().half() + loras[f'{lora_name}.bias_indices'] = indices + loras[f'{lora_name}.bias_values'] = values + loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16) + del extract_a, extract_b, diff + elif decompose_mode == 'full': + loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half() + else: + raise NotImplementedError + elif name in temp_name: + weights = temp_name[name] + lora_name = prefix + '.' + name + lora_name = lora_name.replace('.', '_') + layer = module.__class__.__name__ + + if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}: + root_weight = module.weight + if torch.allclose(root_weight, weights): + continue + + if layer == 'Linear' or layer == 'LoRACompatibleLinear': + weight, decompose_mode = extract_linear( + (root_weight - weights), + mode, + linear_mode_param, + device=extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + elif layer == 'Conv2d' or layer == 'LoRACompatibleConv': + is_linear = ( + root_weight.shape[2] == 1 + and root_weight.shape[3] == 1 + ) + if not is_linear and linear_only: + continue + weight, decompose_mode = extract_conv( + (root_weight - weights), + mode, + linear_mode_param if is_linear else conv_mode_param, + device=extract_device, + ) + if decompose_mode == 'low rank': + extract_a, extract_b, diff = weight + if small_conv and not is_linear and decompose_mode == 'low rank': + dim = extract_a.size(0) + (extract_c, extract_a, _), _ = extract_conv( + extract_a.transpose(0, 1), + 'fixed', dim, + extract_device, True + ) + extract_a = extract_a.transpose(0, 1) + extract_c = extract_c.transpose(0, 1) + loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half() + diff = root_weight - torch.einsum( + 'i j k l, j r, p i -> p r k l', + extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1) + 
).detach().cpu().contiguous() + del extract_c + else: + continue + if decompose_mode == 'low rank': + loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half() + loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half() + loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half() + if use_bias: + diff = diff.detach().cpu().reshape(extract_b.size(0), -1) + sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce() + + indices = sparse_diff.indices().to(torch.int16) + values = sparse_diff.values().half() + loras[f'{lora_name}.bias_indices'] = indices + loras[f'{lora_name}.bias_values'] = values + loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16) + del extract_a, extract_b, diff + elif decompose_mode == 'full': + loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half() + else: + raise NotImplementedError + return loras + + text_encoder_loras = make_state_dict( + LORA_PREFIX_TEXT_ENCODER, + base_model[0], db_model[0], + TEXT_ENCODER_TARGET_REPLACE_MODULE + ) + + unet_loras = make_state_dict( + LORA_PREFIX_UNET, + base_model[2], db_model[2], + UNET_TARGET_REPLACE_MODULE, + UNET_TARGET_REPLACE_NAME + ) + print(len(text_encoder_loras), len(unet_loras)) + # the | will + return (text_encoder_loras | unet_loras), meta + + +def get_module( + lyco_state_dict: Dict, + lora_name +): + if f'{lora_name}.lora_up.weight' in lyco_state_dict: + up = lyco_state_dict[f'{lora_name}.lora_up.weight'] + down = lyco_state_dict[f'{lora_name}.lora_down.weight'] + mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None) + alpha = lyco_state_dict.get(f'{lora_name}.alpha', None) + return 'locon', (up, down, mid, alpha) + elif f'{lora_name}.hada_w1_a' in lyco_state_dict: + w1a = lyco_state_dict[f'{lora_name}.hada_w1_a'] + w1b = lyco_state_dict[f'{lora_name}.hada_w1_b'] + w2a = lyco_state_dict[f'{lora_name}.hada_w2_a'] + w2b = lyco_state_dict[f'{lora_name}.hada_w2_b'] + t1 = 
lyco_state_dict.get(f'{lora_name}.hada_t1', None) + t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None) + alpha = lyco_state_dict.get(f'{lora_name}.alpha', None) + return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha) + elif f'{lora_name}.weight' in lyco_state_dict: + weight = lyco_state_dict[f'{lora_name}.weight'] + on_input = lyco_state_dict.get(f'{lora_name}.on_input', False) + return 'ia3', (weight, on_input) + elif (f'{lora_name}.lokr_w1' in lyco_state_dict + or f'{lora_name}.lokr_w1_a' in lyco_state_dict): + w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None) + w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None) + w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None) + w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None) + w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None) + w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None) + t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None) + t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None) + alpha = lyco_state_dict.get(f'{lora_name}.alpha', None) + return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha) + elif f'{lora_name}.diff' in lyco_state_dict: + return 'full', lyco_state_dict[f'{lora_name}.diff'] + else: + return 'None', () + + +def cp_weight_from_conv( + up, down, mid +): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down) + + +def cp_weight( + wa, wb, t +): + temp = torch.einsum('i j k l, j r -> i r k l', t, wb) + return torch.einsum('i j k l, i r -> r j k l', temp, wa) + + +@torch.no_grad() +def rebuild_weight(module_type, params, orig_weight, scale=1): + if orig_weight is None: + return orig_weight + merged = orig_weight + if module_type == 'locon': + up, down, mid, alpha = params + if alpha is not None: + scale *= alpha / up.size(1) + if mid is not None: + rebuild = cp_weight_from_conv(up, down, mid) + else: + rebuild = up.reshape(up.size(0), -1) @ 
down.reshape(down.size(0), -1) + merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale + del up, down, mid, alpha, params, rebuild + elif module_type == 'hada': + w1a, w1b, w2a, w2b, t1, t2, alpha = params + if alpha is not None: + scale *= alpha / w1b.size(0) + if t1 is not None: + rebuild1 = cp_weight(w1a, w1b, t1) + else: + rebuild1 = w1a @ w1b + if t2 is not None: + rebuild2 = cp_weight(w2a, w2b, t2) + else: + rebuild2 = w2a @ w2b + rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape) + merged = orig_weight + rebuild * scale + del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2 + elif module_type == 'ia3': + weight, on_input = params + if not on_input: + weight = weight.reshape(-1, 1) + merged = orig_weight + weight * orig_weight * scale + del weight, on_input, params + elif module_type == 'kron': + w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params + if alpha is not None and (w1b is not None or w2b is not None): + scale *= alpha / (w1b.size(0) if w1b else w2b.size(0)) + if w1a is not None and w1b is not None: + if t1: + w1 = cp_weight(w1a, w1b, t1) + else: + w1 = w1a @ w1b + if w2a is not None and w2b is not None: + if t2: + w2 = cp_weight(w2a, w2b, t2) + else: + w2 = w2a @ w2b + rebuild = torch.kron(w1, w2).reshape(orig_weight.shape) + merged = orig_weight + rebuild * scale + del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild + elif module_type == 'full': + rebuild = params.reshape(orig_weight.shape) + merged = orig_weight + rebuild * scale + del params, rebuild + + return merged + + +def merge( + base_model, + lyco_state_dict, + scale: float = 1.0, + device='cpu' +): + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D" + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = 
'lora_unet' + LORA_PREFIX_TEXT_ENCODER = 'lora_te' + merged = 0 + + def merge_state_dict( + prefix, + root_module: torch.nn.Module, + lyco_state_dict: Dict[str, torch.Tensor], + target_replace_modules, + target_replace_names=[] + ): + nonlocal merged + for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d', + 'LoRACompatibleConv'}: + continue + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + + result = rebuild_weight(*get_module( + lyco_state_dict, lora_name + ), getattr(child_module, 'weight'), scale) + if result is not None: + merged += 1 + child_module.requires_grad_(False) + child_module.weight.copy_(result) + elif name in target_replace_names: + lora_name = prefix + '.' + name + lora_name = lora_name.replace('.', '_') + + result = rebuild_weight(*get_module( + lyco_state_dict, lora_name + ), getattr(module, 'weight'), scale) + if result is not None: + merged += 1 + module.requires_grad_(False) + module.weight.copy_(result) + + if device == 'cpu': + for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'): + lyco_state_dict[k] = v.float() + + merge_state_dict( + LORA_PREFIX_TEXT_ENCODER, + base_model[0], + lyco_state_dict, + TEXT_ENCODER_TARGET_REPLACE_MODULE, + UNET_TARGET_REPLACE_NAME + ) + merge_state_dict( + LORA_PREFIX_UNET, + base_model[2], + lyco_state_dict, + UNET_TARGET_REPLACE_MODULE, + UNET_TARGET_REPLACE_NAME + ) + print(f'{merged} Modules been merged') diff --git a/ai-toolkit/toolkit/memory_management/__init__.py b/ai-toolkit/toolkit/memory_management/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2eeef37d77844ea9acd0c32f33e4a8d581efd4f9 --- /dev/null +++ b/ai-toolkit/toolkit/memory_management/__init__.py @@ -0,0 +1 @@ 
+from .manager import MemoryManager \ No newline at end of file diff --git a/ai-toolkit/toolkit/memory_management/manager.py b/ai-toolkit/toolkit/memory_management/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0785f06ca6a8ecd293d78a984ae9d751730b2298 --- /dev/null +++ b/ai-toolkit/toolkit/memory_management/manager.py @@ -0,0 +1,226 @@ +import torch +from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager, _DEVICE_STATE +import random + +LINEAR_MODULES = [ + "Linear", + "LoRACompatibleLinear", + "QLinear", +] +CONV_MODULES = [ + "Conv2d", + "LoRACompatibleConv", + "QConv2d", +] + +UNMANAGED_MODULES = [ + "LayerNorm", + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "GroupNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "Embedding", + "EmbeddingBag", + "RNNBase", + "LSTM", + "GRU", + "RNN", + "Conv3d" +] + +UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm", "RotaryPosEmbed"] + + +class MemoryManager: + def __init__( + self, + module: torch.nn.Module, + process_device: torch.device = torch.device("cpu"), + ): + self.module: torch.nn.Module = module + self.process_device: torch.device = process_device + self.unmanaged_modules: list[torch.nn.Module] = [] + + def memory_managed_to(self, *args, **kwargs): + # first move all the unmanaged modules + for module in self.unmanaged_modules: + if isinstance(module, torch.nn.Parameter): + # Parameter cannot move this way + module.data = module.data.to(*args, **kwargs) + else: + module.to(*args, **kwargs) + # check for a dtype argument + dtype = None + if "dtype" in kwargs: + dtype = kwargs["dtype"] + elif len(args) > 0: + for i, arg in enumerate(args): + if isinstance(arg, torch.dtype): + dtype = arg + break + if dtype is not None: + return self.module._mm_to(dtype=dtype) + return self.module + + @classmethod + def attach( + cls, + module: torch.nn.Module, + device: torch.device, + offload_percent: float = 1.0, + ignore_modules: 
list[torch.nn.Module] = [] + ): + if hasattr(module, "_memory_manager"): + # already attached + return + + module._memory_manager = cls(module, device) + + # override the to method to handle memory management + module._mm_to = module.to + module.to = module._memory_manager.memory_managed_to + + # add ignore modules to unmanaged list + for im in ignore_modules: + module._memory_manager.unmanaged_modules.append(im) + + # count ignore modules as processed + modules_processed = [x for x in ignore_modules] + # attach to all modules + for name, sub_module in module.named_modules(): + for child_name, child_module in sub_module.named_modules(): + if ( + child_module.__class__.__name__ in LINEAR_MODULES + and child_module not in modules_processed + ): + skip = False + if offload_percent < 1.0: + # randomly skip some modules + if random.random() > offload_percent: + skip = True + if skip: + module._memory_manager.unmanaged_modules.append(child_module) + else: + # linear + LinearLayerMemoryManager.attach( + child_module, module._memory_manager + ) + # attach to ARA as well + if hasattr(child_module, "ara_lora_ref"): + ara = child_module.ara_lora_ref() + if ara not in modules_processed: + MemoryManager.attach( + ara, + device, + ) + modules_processed.append(child_module) + elif ( + child_module.__class__.__name__ in CONV_MODULES + and child_module not in modules_processed + ): + skip = False + if offload_percent < 1.0: + # randomly skip some modules + if random.random() > offload_percent: + skip = True + if skip: + module._memory_manager.unmanaged_modules.append(child_module) + else: + # conv + ConvLayerMemoryManager.attach( + child_module, module._memory_manager + ) + # attach to ARA as well + if hasattr(child_module, "ara_lora_ref"): + ara = child_module.ara_lora_ref() + if ara not in modules_processed: + MemoryManager.attach( + ara, + device, + ) + modules_processed.append(ara) + modules_processed.append(child_module) + elif child_module.__class__.__name__ in 
UNMANAGED_MODULES or any( + inc in child_module.__class__.__name__ + for inc in UNMANAGED_MODULES_INCLUDES + ): + # unmanaged + module._memory_manager.unmanaged_modules.append(child_module) + else: + continue + + @classmethod + def detach(cls, module: torch.nn.Module): + """ + Reverse of attach(). Moves unmanaged modules back to CPU, restores the + original .to() and forward methods on all child layers, unpins CPU weight + tensors, and clears the global CUDA device state. + + Call this before unloading/replacing a module that had attach() applied. + """ + if not hasattr(module, "_memory_manager"): + return + + for unmanaged in module._memory_manager.unmanaged_modules: + try: + if isinstance(unmanaged, torch.nn.Parameter): + unmanaged.data = unmanaged.data.to('cpu') + else: + unmanaged.to('cpu') + except Exception: + pass + + if hasattr(module, "_mm_to"): + module.to = module._mm_to + del module._mm_to + + del module._memory_manager + + for child in module.modules(): + lmm = getattr(child, "_layer_memory_manager", None) + if lmm is None: + continue + + original_forward = getattr(lmm, "_original_forward", None) + if original_forward is not None: + if hasattr(child, "ara_lora_ref"): + ara = child.ara_lora_ref() + if ara is not None: + ara.org_forward = original_forward + else: + child.forward = original_forward + + for param_name in ("weight", "bias"): + param = getattr(child, param_name, None) + if param is None or not isinstance(param, torch.nn.Parameter): + continue + try: + if param.data.is_pinned(): + object.__setattr__( + child, + param_name, + torch.nn.Parameter( + param.data.clone(), + requires_grad=param.requires_grad, + ), + ) + except Exception: + pass + + del child._layer_memory_manager + if hasattr(child, "_memory_management_device"): + del child._memory_management_device + if hasattr(child, "_is_memory_managed"): + del child._is_memory_managed + + keys_to_delete = [ + dev for dev in _DEVICE_STATE + if isinstance(dev, torch.device) and dev.type == "cuda" 
+ ] + for key in keys_to_delete: + del _DEVICE_STATE[key] + + torch.cuda.empty_cache() diff --git a/ai-toolkit/toolkit/memory_management/manager_modules.py b/ai-toolkit/toolkit/memory_management/manager_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea5b0c60f1f2a254352e278511797e443555149 --- /dev/null +++ b/ai-toolkit/toolkit/memory_management/manager_modules.py @@ -0,0 +1,694 @@ +""" +This code was heavily inspired by the work of Lodestone-Rock, pretty much all credit goes +to them. The original code can be found here: +https://github.com/lodestone-rock/RamTorch/blob/main/ramtorch/modules/linear.py + +I simply modified it to work with a memory management model and with AI Toolkit's models +""" + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import TYPE_CHECKING, Optional, Tuple +from torch.overrides import has_torch_function_unary # (ADD) torchao detection + +if TYPE_CHECKING: + from .manager import MemoryManager + +# --- Per-device global state registry --- +_DEVICE_STATE = {} + +# How many layers deep to prefetch weights. The old ping-pong used 2 slots, which +# only lets one transfer overlap one compute (1-deep). A deeper ring lets Python +# enqueue several layers ahead so the H2D stream stays saturated instead of +# stalling on a per-layer sync. Override with AI_TOOLKIT_OFFLOAD_DEPTH. 
+PIPELINE_DEPTH = int(os.environ.get("AI_TOOLKIT_OFFLOAD_DEPTH", "4")) + + +def _get_device_state(device: torch.device): + """Get or initialize per-device state.""" + if isinstance(device, str): + device = torch.device(device) + + # CPU path needs no CUDA state + if device.type != "cuda": + if device not in _DEVICE_STATE: + _DEVICE_STATE[device] = {} + return _DEVICE_STATE[device] + + if device not in _DEVICE_STATE: + d = max(2, PIPELINE_DEPTH) + with torch.cuda.device(device): + _DEVICE_STATE[device] = { + "depth": d, + # streams + "transfer_stream": torch.cuda.Stream(device=device), + "transfer_grad_stream": torch.cuda.Stream(device=device), + # forward weight ring: slot_ready = H2D done, slot_free = compute + # that consumed the slot done (so it can be overwritten). + "w_buffers": [None] * d, + "b_buffers": [None] * d, + "fwd_slot_ready": [torch.cuda.Event() for _ in range(d)], + "fwd_slot_free": [torch.cuda.Event() for _ in range(d)], + "forward_clk": 0, + # backward weight ring (re-fetch for grad-input). + "w_bwd_buffers": [None] * d, + "bwd_slot_ready": [torch.cuda.Event() for _ in range(d)], + "bwd_slot_free": [torch.cuda.Event() for _ in range(d)], + "backward_clk": 0, + # backward grad-staging ring (device-side grads -> CPU). + "w_grad_buffers": [None] * d, + "b_grad_buffers": [None] * d, + "grad_compute_done": [torch.cuda.Event() for _ in range(d)], + "grad_xfer_done": [torch.cuda.Event() for _ in range(d)], + } + return _DEVICE_STATE[device] + + +# ---- ring-buffer staging helpers ----------------------------------------- +# +# Each transfer waits only on the event for the *specific slot* it is about to +# overwrite (the compute that used that slot D layers ago), not on a single +# global "compute started" event. With D slots that prior compute is long done, +# so the transfer stream never actually stalls and stays D layers ahead of +# compute. This is the deeper-pipeline + relaxed-dependency change in one. 
+ + +def _stage_forward_weight(state, device, materialize, weight_cpu, bias_cpu): + """H2D the next forward weight (+bias) into its ring slot; return (idx, w, b). + Caller runs compute, then calls _release_forward_slot(state, idx).""" + d = state["depth"] + idx = state["forward_clk"] + state["forward_clk"] = (idx + 1) % d + ts = state["transfer_stream"] + with torch.cuda.stream(ts): + ts.wait_event(state["fwd_slot_free"][idx]) + state["w_buffers"][idx] = materialize(weight_cpu, device) + state["b_buffers"][idx] = ( + bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None + ) + state["fwd_slot_ready"][idx].record() + torch.cuda.current_stream().wait_event(state["fwd_slot_ready"][idx]) + return idx, state["w_buffers"][idx], state["b_buffers"][idx] + + +def _release_forward_slot(state, idx): + # Slot is reusable once the compute stream finishes the op that read it. + state["fwd_slot_free"][idx].record() + + +def _stage_backward_weight(state, device, materialize, weight_cpu): + """H2D the next backward weight into its ring slot; return (idx, w). + Caller runs grad-input compute, then _release_backward_weight_slot.""" + d = state["depth"] + idx = state["backward_clk"] + state["backward_clk"] = (idx + 1) % d + ts = state["transfer_stream"] + with torch.cuda.stream(ts): + ts.wait_event(state["bwd_slot_free"][idx]) + state["w_bwd_buffers"][idx] = materialize(weight_cpu) + state["bwd_slot_ready"][idx].record() + torch.cuda.current_stream().wait_event(state["bwd_slot_ready"][idx]) + return idx, state["w_bwd_buffers"][idx] + + +def _release_backward_weight_slot(state, idx): + state["bwd_slot_free"][idx].record() + + +def _stage_grads_to_cpu(state, idx, grad_w_gpu, grad_b_gpu): + """Copy freshly-computed device grads (in staging slot idx) to CPU on the + grad stream, overlapping the next H2D. 
Returns (grad_w_cpu, grad_b_cpu).""" + gs = state["transfer_grad_stream"] + state["grad_compute_done"][idx].record() # on the compute stream + grad_w_cpu = grad_b_cpu = None + with torch.cuda.stream(gs): + gs.wait_event(state["grad_compute_done"][idx]) + if grad_w_gpu is not None: + grad_w_cpu = grad_w_gpu.to("cpu", non_blocking=True) + if grad_b_gpu is not None: + grad_b_cpu = grad_b_gpu.to("cpu", non_blocking=True) + state["grad_xfer_done"][idx].record() + return grad_w_cpu, grad_b_cpu + + +# (ADD) detect torchao wrapper tensors +def _is_ao_quantized_tensor(t: Optional[torch.Tensor]) -> bool: + if t is None: + return False + try: + if has_torch_function_unary(t): + return t.__class__.__module__.startswith("torchao.") + except Exception: + pass + for attr in ( + "_scale", + "_scales", + "_zero_point", + "_zp", + "_block_size", + "_group_size", + "_pack_dim", + ): + if hasattr(t, attr): + return True + return False + + +def _is_quantized_tensor(t: Optional[torch.Tensor]) -> bool: + if t is None: + return False + # torch quantized tensors + try: + if torch.is_quantized(t): # type: ignore[attr-defined] + return True + except Exception: + pass + # (ADD) torchao quantized wrappers + if _is_ao_quantized_tensor(t): + return True + # packed/int formats (weight-only) + return not t.dtype.is_floating_point + + +def _pin_inner_tensors(t: torch.Tensor) -> None: + """Pin the leaf storage of a tensor-subclass (e.g. torchao float8) in place. + + Quantized wrappers can't be pin_memory()'d directly, but they expose their + real data as inner tensors via __tensor_flatten__. Pinning those lets the + per-layer H2D bounce run async and overlap with compute instead of blocking. 
+ """ + try: + names, _ = t.__tensor_flatten__() + except Exception: + return + for name in names: + inner = getattr(t, name, None) + if inner is None: + continue + if hasattr(inner, "__tensor_flatten__"): + _pin_inner_tensors(inner) # recurse: AQT -> tensor_impl -> data/scale + elif ( + isinstance(inner, torch.Tensor) + and inner.device.type == "cpu" + and not inner.is_pinned() + ): + try: + setattr(t, name, inner.pin_memory()) + except Exception: + pass + + +def _ensure_cpu_pinned(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if t is None: + return None + if t.device.type != "cpu": + try: + t = t.to("cpu", copy=True) + except Exception: + t = t.to("cpu") + # Quantized wrappers can't be pin_memory()'d directly, but pinning their + # inner storage gives the same async-transfer benefit. + if _is_quantized_tensor(t): + if torch.cuda.is_available(): + _pin_inner_tensors(t) + return t + if torch.cuda.is_available(): + try: + t = t.pin_memory() + except RuntimeError: + pass + return t + + +def _move_params_to_cpu_and_pin(module: nn.Module): + """Force parameters to CPU (+pinned) so we can 'bounce' them per forward/backward.""" + with torch.no_grad(): + for name in ("weight", "bias"): + param = getattr(module, name, None) + if not isinstance(param, nn.Parameter): + continue + cpu_data = _ensure_cpu_pinned(param.data).detach() + if _is_quantized_tensor(param.data): + # Tensor-subclass weights (e.g. torchao float8 AffineQuantizedTensor) + # ignore `param.data = ...`: the wrapper reports CPU but its inner + # storage stays on the GPU, so the weight never actually offloads. + # Replace the whole Parameter so the device move sticks. 
+ setattr( + module, + name, + nn.Parameter(cpu_data, requires_grad=param.requires_grad), + ) + else: + param.data = cpu_data + + +# ========================== +# Autograd functions (CUDA) +# ========================== + + +class _BouncingLinearFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight_cpu, bias_cpu, device: torch.device): + # choose compute dtype to match activations + target_dtype = ( + x.dtype + if x.dtype in (torch.bfloat16, torch.float16, torch.float32) + else torch.bfloat16 + ) + + # GPU-side dequant/cast for quantized; float path unchanged + def _materialize_linear_weight(cpu_w, dev): + if _is_quantized_tensor(cpu_w): + # move quantized wrapper to GPU -> dequantize on GPU -> cast on GPU + w_q_gpu = cpu_w.to(dev, non_blocking=True) + try: + w_fp_gpu = w_q_gpu.dequantize() + except Exception: + w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True) + if w_fp_gpu.dtype != target_dtype: + w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True) + return w_fp_gpu + # float path (preserve original behavior: NO dtype cast) + w_gpu = cpu_w.to(dev, non_blocking=True) + return w_gpu + + if device.type != "cuda": + out = F.linear( + x.to("cpu"), + _materialize_linear_weight(weight_cpu, torch.device("cpu")), + bias_cpu, + ) + ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu) + ctx.device = torch.device("cpu") + return out.to(x.device) + + state = _get_device_state(device) + idx, w_gpu, b_gpu = _stage_forward_weight( + state, device, _materialize_linear_weight, weight_cpu, bias_cpu + ) + out = F.linear(x, w_gpu, b_gpu) + _release_forward_slot(state, idx) + + ctx.save_for_backward(x, weight_cpu, bias_cpu) + ctx.device = device + ctx.target_dtype = target_dtype + return out + + @staticmethod + def backward(ctx, grad_out): + x, weight_cpu, bias_cpu = ctx.saved_tensors + device = ctx.device + target_dtype = getattr(ctx, "target_dtype", grad_out.dtype) + + if device.type != "cuda": + go_cpu = grad_out.to("cpu") + x_cpu = 
x.to("cpu") + w_mat = ( + weight_cpu.dequantize() + if _is_quantized_tensor(weight_cpu) + else weight_cpu + ) + if w_mat.dtype != target_dtype and target_dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ): + w_mat = w_mat.to(target_dtype) + grad_input = go_cpu @ w_mat + grad_weight = ( + go_cpu.flatten(0, -2).T @ x_cpu.flatten(0, -2) + if getattr(weight_cpu, "requires_grad", False) + and weight_cpu.dtype.is_floating_point + else None + ) + grad_bias = ( + go_cpu.sum(dim=tuple(range(go_cpu.ndim - 1))) + if (bias_cpu is not None and getattr(bias_cpu, "requires_grad", False)) + else None + ) + return grad_input.to(grad_out.device), grad_weight, grad_bias, None + + state = _get_device_state(device) + + # GPU-side dequant/cast for quantized; float path unchanged + def _materialize_for_bwd(cpu_w): + if _is_quantized_tensor(cpu_w): + w_q_gpu = cpu_w.to(device, non_blocking=True) + try: + w_fp_gpu = w_q_gpu.dequantize() + except Exception: + w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True) + if w_fp_gpu.dtype != target_dtype: + w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True) + return w_fp_gpu + # float path (preserve original behavior: NO dtype cast) + w = cpu_w.to(device, non_blocking=True) + return w + + idx, w_bwd = _stage_backward_weight( + state, device, _materialize_for_bwd, weight_cpu + ) + + # grad wrt input (GPU) + grad_input = grad_out.to(dtype=target_dtype) @ w_bwd + _release_backward_weight_slot(state, idx) + + # compute grads if float masters exist (frozen/quantized bases skip this) + grad_weight = None + grad_bias = None + need_w = ( + getattr(weight_cpu, "requires_grad", False) + and weight_cpu.dtype.is_floating_point + ) + need_b = bias_cpu is not None and getattr(bias_cpu, "requires_grad", False) + if need_w or need_b: + # ensure the prior grad D2H using this staging slot finished + torch.cuda.current_stream().wait_event(state["grad_xfer_done"][idx]) + w_grad_gpu = b_grad_gpu = None + if need_w: + w_grad_gpu = 
grad_out.flatten(0, -2).T @ x.flatten(0, -2) + state["w_grad_buffers"][idx] = w_grad_gpu + if need_b: + b_grad_gpu = grad_out.sum(dim=tuple(range(grad_out.ndim - 1))) + state["b_grad_buffers"][idx] = b_grad_gpu + grad_weight, grad_bias = _stage_grads_to_cpu( + state, idx, w_grad_gpu, b_grad_gpu + ) + + return grad_input.to(dtype=grad_out.dtype), grad_weight, grad_bias, None + + +class _BouncingConv2dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight_cpu, + bias_cpu, + device: torch.device, + stride: Tuple[int, int], + padding: Tuple[int, int], + dilation: Tuple[int, int], + groups: int, + ): + target_dtype = ( + x.dtype + if x.dtype in (torch.bfloat16, torch.float16, torch.float32) + else torch.bfloat16 + ) + + # GPU-side dequant/cast for quantized; float path unchanged + def _materialize_conv_weight(cpu_w, dev): + if _is_quantized_tensor(cpu_w): + w_q_gpu = cpu_w.to(dev, non_blocking=True) + try: + w_fp_gpu = w_q_gpu.dequantize() + except Exception: + w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True) + if w_fp_gpu.dtype != target_dtype: + w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True) + return w_fp_gpu + # float path (preserve original behavior: NO dtype cast) + w_gpu = cpu_w.to(dev, non_blocking=True) + return w_gpu + + if device.type != "cuda": + out = F.conv2d( + x.to("cpu"), + _materialize_conv_weight(weight_cpu, torch.device("cpu")), + bias_cpu, + stride, + padding, + dilation, + groups, + ) + ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu) + ctx.meta = ("cpu", stride, padding, dilation, groups, target_dtype) + return out.to(x.device) + + state = _get_device_state(device) + idx, w_gpu, b_gpu = _stage_forward_weight( + state, device, _materialize_conv_weight, weight_cpu, bias_cpu + ) + out = F.conv2d(x, w_gpu, b_gpu, stride, padding, dilation, groups) + _release_forward_slot(state, idx) + + ctx.save_for_backward(x, weight_cpu, bias_cpu) + ctx.meta = (device, stride, padding, dilation, groups, 
target_dtype) + return out + + @staticmethod + def backward(ctx, grad_out): + x, weight_cpu, bias_cpu = ctx.saved_tensors + device, stride, padding, dilation, groups, target_dtype = ctx.meta + + if ( + isinstance(device, torch.device) and device.type != "cuda" + ) or device == "cpu": + go = grad_out.to("cpu") + x_cpu = x.to("cpu") + w_cpu = ( + weight_cpu.dequantize() + if _is_quantized_tensor(weight_cpu) + else weight_cpu + ) + if w_cpu.dtype != target_dtype and target_dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ): + w_cpu = w_cpu.to(target_dtype) + from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore + + grad_input = conv2d_input( + x_cpu.shape, + w_cpu, + go, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + grad_weight = ( + conv2d_weight( + x_cpu, + w_cpu.shape, + go, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + if getattr(weight_cpu, "requires_grad", False) + and weight_cpu.dtype.is_floating_point + else None + ) + grad_bias = ( + go.sum(dim=(0, 2, 3)) + if (bias_cpu is not None and getattr(bias_cpu, "requires_grad", False)) + else None + ) + return ( + grad_input.to(grad_out.device), + grad_weight, + grad_bias, + None, + None, + None, + None, + None, + ) + + state = _get_device_state(device) + + # GPU-side dequant/cast for quantized; float path unchanged + def _materialize_for_bwd(cpu_w): + if _is_quantized_tensor(cpu_w): + w_q_gpu = cpu_w.to(device, non_blocking=True) + try: + w_fp_gpu = w_q_gpu.dequantize() + except Exception: + w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True) + if w_fp_gpu.dtype != target_dtype: + w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True) + return w_fp_gpu + # float path (preserve original behavior: NO dtype cast) + w = cpu_w.to(device, non_blocking=True) + return w + + idx, w_bwd = _stage_backward_weight( + state, device, _materialize_for_bwd, weight_cpu + ) + + from torch.nn.grad import conv2d_input, 
conv2d_weight # type: ignore + + grad_input = conv2d_input( + x.shape, + w_bwd, + grad_out.to(dtype=target_dtype), + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + _release_backward_weight_slot(state, idx) + + # Compute heavy grads on GPU into staging buffers (frozen bases skip this) + grad_weight = None + grad_bias = None + need_w = ( + getattr(weight_cpu, "requires_grad", False) + and weight_cpu.dtype.is_floating_point + ) + need_b = bias_cpu is not None and getattr(bias_cpu, "requires_grad", False) + if need_w or need_b: + torch.cuda.current_stream().wait_event(state["grad_xfer_done"][idx]) + w_grad_gpu = b_grad_gpu = None + if need_w: + w_grad_gpu = conv2d_weight( + x, + weight_cpu.shape, + grad_out, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + state["w_grad_buffers"][idx] = w_grad_gpu + if need_b: + b_grad_gpu = grad_out.sum(dim=(0, 2, 3)) + state["b_grad_buffers"][idx] = b_grad_gpu + grad_weight, grad_bias = _stage_grads_to_cpu( + state, idx, w_grad_gpu, b_grad_gpu + ) + + return ( + grad_input.to(dtype=grad_out.dtype), + grad_weight, + grad_bias, + None, + None, + None, + None, + None, + ) + + +class BaseLayerMemoryManager: + def __init__( + self, + module: nn.Module, + manager: "MemoryManager", + ): + self.module: nn.Module = module + self.manager: "MemoryManager" = manager + + @classmethod + def attach(cls, module: nn.Module, manager: "MemoryManager"): + if hasattr(module, "_layer_memory_manager"): + return + module._layer_memory_manager = cls(module, manager) + + # mark parameters as memory managed + for param in module.parameters(recurse=False): + param._is_memory_managed = True + + +class LinearLayerMemoryManager(BaseLayerMemoryManager): + def __init__( + self, + module: nn.Module, + manager: "MemoryManager", + ): + super().__init__(module, manager) + + # 1) Move params to CPU + pin memory for fast H2D + _move_params_to_cpu_and_pin(self.module) + + # 2) Hijack forward + if 
hasattr(self.module, "ara_lora_ref"): + # ARA, we need to replace the lora forward + self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward") + else: + self._original_forward = getattr(self.module, "forward") + + def _mm_forward(x, *args, **kwargs): + # ensure we only use expected signature (Linear: x) + if args or kwargs: + # fall back to original if a custom signature is used + return self._original_forward(x, *args, **kwargs) + + weight_cpu = self.module.weight + bias_cpu = getattr(self.module, "bias", None) + device = self.manager.process_device + + # NOTE: do NOT move params to device here; autograd fn streams & bounces them + return _BouncingLinearFn.apply(x, weight_cpu, bias_cpu, device) + + if hasattr(self.module, "ara_lora_ref"): + self.module.ara_lora_ref().org_forward = _mm_forward + else: + self.module.forward = _mm_forward + + self.module._memory_management_device = self.manager.process_device + + +class ConvLayerMemoryManager(BaseLayerMemoryManager): + def __init__( + self, + module: nn.Module, + manager: "MemoryManager", + ): + super().__init__(module, manager) + + # 1) Move params to CPU + pin memory for fast H2D + _move_params_to_cpu_and_pin(self.module) + + # Cache static conv attributes from the module + stride = ( + self.module.stride + if isinstance(self.module.stride, tuple) + else (self.module.stride, self.module.stride) + ) + padding = ( + self.module.padding + if isinstance(self.module.padding, tuple) + else (self.module.padding, self.module.padding) + ) + dilation = ( + self.module.dilation + if isinstance(self.module.dilation, tuple) + else (self.module.dilation, self.module.dilation) + ) + groups = self.module.groups + + # 2) Hijack forward + if hasattr(self.module, "ara_lora_ref"): + # ARA, we need to replace the lora forward + self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward") + else: + self._original_forward = getattr(self.module, "forward") + + def _mm_forward(x, *args, **kwargs): + # 
Support the typical Conv2d(x) call; if user passes uncommon extras, fallback. + if args or kwargs: + return self._original_forward(x, *args, **kwargs) + + weight_cpu = self.module.weight + bias_cpu = getattr(self.module, "bias", None) + device = self.manager.process_device + + return _BouncingConv2dFn.apply( + x, weight_cpu, bias_cpu, device, stride, padding, dilation, groups + ) + + if hasattr(self.module, "ara_lora_ref"): + self.module.ara_lora_ref().org_forward = _mm_forward + else: + self.module.forward = _mm_forward + + self.module._memory_management_device = self.manager.process_device diff --git a/ai-toolkit/toolkit/memory_management/test_memory_manager.py b/ai-toolkit/toolkit/memory_management/test_memory_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a6430867d19ed7e6ebf594de72f7a2667dc5c369 --- /dev/null +++ b/ai-toolkit/toolkit/memory_management/test_memory_manager.py @@ -0,0 +1,333 @@ +""" +Memory-manager (layer offloading) benchmark on a ~1B parameter diffusion-style +transformer. The base model is frozen and a LoRA is trained on top of it (the +realistic training setup), so only the LoRA params get grads/optimizer state. + +Reports speed (ms/step) and peak VRAM for the 2x2 matrix of: + + - bfloat16 base vs bfloat16 + float8-quantized base (torchao weight-only) + - no offloading vs 100% offloading + +The MemoryManager keeps the frozen base weights pinned on the CPU and streams +them onto the GPU per forward/backward (dequantizing float8 weights on the GPU). +The LoRA wraps each base linear, so its forward calls the bounced base forward +and adds the low-rank update. This trades VRAM for PCIe traffic, so the table +shows what that trade actually costs. 
+ +Run directly: `python test_memory_manager.py` +""" +import contextlib +import gc +import io +import os +import sys +import threading +import time + +import psutil +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class RamMonitor: + """Sample process RSS in a background thread and track the peak. Pinned + CPU weights (from offloading) live in RSS, so this captures the host-RAM + cost the GPU-side peak doesn't see.""" + + def __init__(self, interval: float = 0.005): + self.interval = interval + self._proc = psutil.Process() + self.peak = 0 + + def _run(self): + while not self._stop: + self.peak = max(self.peak, self._proc.memory_info().rss) + time.sleep(self.interval) + + def __enter__(self): + self.peak = self._proc.memory_info().rss + self._stop = False + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + return self + + def __exit__(self, *exc): + self._stop = True + self._thread.join() + +# Allow running this file directly without setting PYTHONPATH. +# Toolkit imports happen inside main() so they pick this up. 
+_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + + +# ---- model --------------------------------------------------------------- + +class TransformerBlock(nn.Module): + def __init__(self, d_model: int, n_heads: int, d_ff: int): + super().__init__() + self.n_heads = n_heads + self.d_head = d_model // n_heads + self.ln1 = nn.LayerNorm(d_model) + self.q = nn.Linear(d_model, d_model, bias=False) + self.k = nn.Linear(d_model, d_model, bias=False) + self.v = nn.Linear(d_model, d_model, bias=False) + self.o = nn.Linear(d_model, d_model, bias=False) + self.ln2 = nn.LayerNorm(d_model) + self.ffn_up = nn.Linear(d_model, d_ff, bias=False) + self.ffn_down = nn.Linear(d_ff, d_model, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, S, D = x.shape + h = self.ln1(x) + q = self.q(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2) + k = self.k(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2) + v = self.v(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2) + a = F.scaled_dot_product_attention(q, k, v) + a = a.transpose(1, 2).contiguous().view(B, S, D) + x = x + self.o(a) + h = self.ln2(x) + x = x + self.ffn_down(F.gelu(self.ffn_up(h))) + return x + + +class Transformer(nn.Module): + def __init__(self, d_model=2048, n_heads=16, n_layers=24, d_ff=8192): + super().__init__() + self.blocks = nn.ModuleList([ + TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers) + ]) + self.norm = nn.LayerNorm(d_model) + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for b in self.blocks: + # Gate on is_grad_enabled (not self.training): checkpointing only + # helps and only works when a backward will actually be run. 
+ if self.gradient_checkpointing and torch.is_grad_enabled(): + x = torch.utils.checkpoint.checkpoint(b, x, use_reentrant=False) + else: + x = b(x) + return self.norm(x) + + +# ---- benchmark ----------------------------------------------------------- + +DEVICE = torch.device("cuda") +DTYPE = torch.bfloat16 +QTYPE = "float8" +D_MODEL = 2048 +N_HEADS = 16 +N_LAYERS = 24 +D_FF = 8192 +BATCH = 1 +SEQ = 1024 +WARMUP = 3 +ITERS = 10 +LORA_RANK = 32 +LR = 1e-4 + +# Full matrix: {bf16, float8} x {no offload, 100% offload} x {ckpt on, ckpt off}. +# Offloading parks weights in CPU RAM; turning off checkpointing keeps activations +# resident in VRAM. We report peak VRAM *and* peak system RAM so both show up. +# (label, quantize, offload_percent, grad_checkpointing) +RUNS = [] +for _do_q, _qlabel in [(False, "bf16"), (True, "float8")]: + for _off, _olabel in [(None, ""), (1.0, "+off")]: + for _ckpt in [True, False]: + _label = f"{_qlabel}{_olabel} ckpt={'on' if _ckpt else 'off'}" + RUNS.append((_label, _do_q, _off, _ckpt)) + + +def build_model(): + torch.manual_seed(0) + # Build on CPU; the caller decides how it reaches the GPU. + return Transformer(D_MODEL, N_HEADS, N_LAYERS, D_FF).to(dtype=DTYPE) + + +def build_lora(transformer): + """Attach a trainable LoRA to the (frozen) transformer, the same way the + trainer does it. 
Returns the network; its forward hijacks each base linear.""" + from toolkit.config_modules import NetworkConfig + from toolkit.lora_special import LoRASpecialNetwork + + network_config = NetworkConfig( + type="lora", + linear=LORA_RANK, + linear_alpha=LORA_RANK, + transformer_only=True, + ) + LoRASpecialNetwork.LORA_PREFIX_UNET = "lora_transformer" + network = LoRASpecialNetwork( + text_encoder=None, + unet=transformer, + lora_dim=network_config.linear, + multiplier=1.0, + alpha=network_config.linear_alpha, + train_unet=True, + train_text_encoder=False, + network_config=network_config, + network_type=network_config.type, + transformer_only=network_config.transformer_only, + is_transformer=True, + target_lin_modules=["Transformer"], + ) + network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True) + network.force_to(DEVICE, DTYPE) + network._update_torch_multiplier() + network.is_active = True + network.train() + return network + + +def benchmark(results: list, label: str, do_quantize: bool, offload_percent, grad_checkpointing): + from toolkit.memory_management import MemoryManager + from toolkit.util.quantize import quantize, get_qtype + from optimum.quanto import freeze + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + network = None + model = build_model() + model.gradient_checkpointing = grad_checkpointing + model.to(DEVICE) + + if do_quantize: + # Quantize the linear weights to float8 on the GPU (torchao weight-only), + # exactly as a quantized base model is prepared before training. + quantize(model, weights=get_qtype(QTYPE)) + freeze(model) + + # Base model is frozen; only the LoRA trains. + model.requires_grad_(False) + + if offload_percent is None: + # Baseline: whole base model stays on the GPU. 
+ model.to(DEVICE) + else: + # Offloading: managed linears stay pinned on CPU and bounce per step + # (float8 weights are dequantized on the GPU); unmanaged modules (norms) + # move to the GPU via the patched .to(). Attach BEFORE the LoRA so the + # LoRA wraps the bouncing forward. Layer sampling is seeded for repro. + import random + random.seed(0) + MemoryManager.attach(model, DEVICE, offload_percent=offload_percent) + model.to(DEVICE) + + # build_lora prints a banner per layer; mute it so the final table is clean. + with contextlib.redirect_stdout(io.StringIO()): + network = build_lora(model) + params = network.prepare_optimizer_params(LR, LR, LR) + opt = torch.optim.AdamW(params, lr=LR) + x = torch.randn(BATCH, SEQ, D_MODEL, device=DEVICE, dtype=DTYPE) + + try: + for _ in range(WARMUP): + opt.zero_grad(set_to_none=True) + model(x).sum().backward() + opt.step() + torch.cuda.synchronize() + + # Measure the steady-state TRAINING peak, not the one-time setup load. + # (Offload first parks the whole model on the GPU before bouncing it to + # CPU; counting that transient would hide the real per-step footprint.) 
+ torch.cuda.reset_peak_memory_stats() + + t0 = time.perf_counter() + with RamMonitor() as ram: + for _ in range(ITERS): + opt.zero_grad(set_to_none=True) + model(x).sum().backward() + opt.step() + torch.cuda.synchronize() + dt = (time.perf_counter() - t0) / ITERS * 1000 + peak = torch.cuda.max_memory_allocated() / 1024**3 + ram_gb = ram.peak / 1024**3 + results.append({"label": label, "ms": dt, "peak": peak, "ram": ram_gb, "ok": True}) + except torch.cuda.OutOfMemoryError: + results.append({"label": label, "ms": float("inf"), "peak": float("inf"), "ram": float("inf"), "ok": False, "note": "OOM"}) + except Exception as e: + print(f" {label} failed: {type(e).__name__}: {e}", flush=True) + results.append({"label": label, "ms": float("inf"), "peak": float("inf"), "ram": float("inf"), "ok": False, "note": "ERR"}) + finally: + if offload_percent is not None: + MemoryManager.detach(model) + del opt, network, model + gc.collect() + torch.cuda.empty_cache() + + +def print_table(results: list): + headers = ["#", "Configuration", "Peak VRAM", "Peak RAM", "Time/step"] + rows = [] + for i, r in enumerate(results, 1): + if not r["ok"]: + rows.append([str(i), r["label"], r.get("note", "OOM"), "-", "-"]) + continue + rows.append([str(i), r["label"], f"{r['peak']:.2f} GB", f"{r['ram']:.2f} GB", f"{r['ms']:.1f} ms"]) + + widths = [max(len(str(row[c])) for row in [headers] + rows) for c in range(len(headers))] + + def fmt(row, sep=" │ "): + return sep.join(s.ljust(widths[c]) if c == 1 else s.rjust(widths[c]) for c, s in enumerate(row)) + + line_top = "─" * (sum(widths) + 3 * (len(widths) - 1)) + print() + print(line_top) + print(fmt(headers)) + print(line_top) + for row in rows: + print(fmt(row)) + print(line_top) + + +def run_one(idx: int): + """Run a single config and print its result as JSON. 
Invoked in a fresh + subprocess so peak RAM (and VRAM) are isolated — pinned-host and CUDA host + caches don't release between runs, so in-process RAM peaks would accumulate.""" + import json + + label, do_quantize, offload_percent, grad_checkpointing = RUNS[idx] + results: list = [] + benchmark(results, label, do_quantize, offload_percent, grad_checkpointing) + print("RESULT " + json.dumps(results[0]), flush=True) + + +def main(): + import json + import subprocess + + n_params = sum(p.numel() for p in build_model().parameters()) + print(f"Model: {N_LAYERS} blocks × d_model={D_MODEL} × d_ff={D_FF}") + print(f" {n_params/1e6:.1f}M params") + print(f"dtype: {str(DTYPE).replace('torch.', '')} (quant qtype: {QTYPE})") + print(f"Train: LoRA rank={LORA_RANK} on a frozen base") + print(f"Step: batch={BATCH}, seq={SEQ}") + print(f"Timing: {WARMUP} warmup + {ITERS} timed iters") + print(f"Configs: {len(RUNS)} (each in an isolated subprocess)") + + results: list = [] + for idx, run in enumerate(RUNS): + print(f" running {run[0]}...", flush=True) + proc = subprocess.run( + [sys.executable, __file__, "--run", str(idx)], + capture_output=True, text=True, + ) + line = next((ln for ln in proc.stdout.splitlines() if ln.startswith("RESULT ")), None) + if line is None: + print(f" {run[0]} produced no result:\n{proc.stdout}\n{proc.stderr}", flush=True) + results.append({"label": run[0], "ok": False, "note": "ERR"}) + continue + results.append(json.loads(line[len("RESULT "):])) + print_table(results) + + +if __name__ == "__main__": + if len(sys.argv) >= 3 and sys.argv[1] == "--run": + run_one(int(sys.argv[2])) + else: + main() diff --git a/ai-toolkit/toolkit/metadata.py b/ai-toolkit/toolkit/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5c36adae70feb2624a84a3b8dbe05f24ed60ed --- /dev/null +++ b/ai-toolkit/toolkit/metadata.py @@ -0,0 +1,88 @@ +import json +from collections import OrderedDict +from io import BytesIO + +import safetensors +from 
def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=True) -> OrderedDict:
    """Build a flat, string-valued metadata dict for a safetensors header.

    The input is round-tripped through JSON, so ``meta`` itself is never
    mutated and nested mappings come back as ``OrderedDict``.

    Args:
        meta: metadata to convert; must be JSON-serialisable.
        name: if not ``None``, every literal ``[name]`` placeholder in the
            keys and values of ``meta`` is replaced by this value. The value
            is JSON-escaped first, so quotes, backslashes and control
            characters in the name are inserted verbatim.
        add_software_info: when true, a ``"software"`` entry holding
            ``software_meta`` is added.

    Returns:
        A new ``OrderedDict`` in which every non-string value is a JSON
        string, plus ``"format": "pt"``.
    """
    # Stringify the meta so the placeholder can be replaced everywhere at once.
    meta_string = json.dumps(meta)
    if name is not None:
        # The replacement lands inside JSON string literals, so it has to be
        # escaped the way JSON escapes string contents. json.dumps() yields a
        # quoted literal; dropping the surrounding quotes leaves the escaped body.
        escaped_name = json.dumps(str(name))[1:-1]
        meta_string = meta_string.replace("[name]", escaped_name)
    save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
    if add_software_info:
        save_meta["software"] = software_meta
    # safetensors metadata can only be one level deep: JSON-encode anything
    # that is not already a string. Only existing keys are reassigned, so the
    # dict size does not change while it is being iterated.
    for key, value in save_meta.items():
        if not isinstance(value, str):
            save_meta[key] = json.dumps(value)
    # add the pt format
    save_meta["format"] = "pt"
    return save_meta
def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
    """Decode the JSON-encoded values of a safetensors metadata mapping.

    Args:
        meta: mapping of metadata keys to values, or ``None`` for a file
            that carries no metadata.

    Returns:
        An ``OrderedDict`` with the same keys in the same order. Values that
        parse as JSON are replaced by the decoded object; values that are not
        valid JSON, or are not strings at all, are kept unchanged. ``None``
        input yields an empty ``OrderedDict``.
    """
    parsed_meta = OrderedDict()
    if meta is None:
        # No metadata header: nothing to decode.
        return parsed_meta
    for key, value in meta.items():
        try:
            parsed_meta[key] = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            # JSONDecodeError: a plain string that is not JSON.
            # TypeError: the value is not str/bytes (e.g. already decoded).
            parsed_meta[key] = value
    return parsed_meta
network: 'LoRASpecialNetwork' = None, + use_bias: bool = False, + **kwargs + ): + self.can_merge_in = False + """if alpha == 0 or None, alpha is rank (no scaling).""" + ToolkitModuleMixin.__init__(self, network=network) + torch.nn.Module.__init__(self) + self.lora_name = lora_name + self.scalar = torch.tensor(1.0) + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ in CONV_MODULES: + raise NotImplementedError("Convolutional layers are not supported yet") + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + # self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える eng: treat as constant + + self.multiplier: Union[float, List[float]] = multiplier + # wrap the original module so it doesn't get weights updated + self.org_module = [org_module] + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.is_checkpointing = False + + d_out = org_module.out_features + d_in = org_module.in_features + + std_dev = 1 / torch.sqrt(torch.tensor(self.lora_dim).float()) + # self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev) # lora_A + # self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in)) # lora_B + self.lora_up = nn.Linear(self.lora_dim, d_out, bias=False) # lora_B + # self.lora_up.weight.data = torch.randn_like(self.lora_up.weight.data) * std_dev + self.lora_up.weight.data = torch.zeros_like(self.lora_up.weight.data) + # self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False) + # self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False) + self.lora_down = nn.Linear(d_in, self.lora_dim, bias=False) # lora_A + # self.lora_down.weight.data = torch.zeros_like(self.lora_down.weight.data) + self.lora_down.weight.data = torch.randn_like(self.lora_down.weight.data) * std_dev + + # m = Magnitude 
column-wise across output dimension + weight = self.get_orig_weight() + weight = weight.to(self.lora_up.weight.device, dtype=self.lora_up.weight.dtype) + lora_weight = self.lora_up.weight @ self.lora_down.weight + weight_norm = self._get_weight_norm(weight, lora_weight) + self.magnitude = nn.Parameter(weight_norm.detach().clone(), requires_grad=True) + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + # del self.org_module + + def get_orig_weight(self): + weight = self.org_module[0].weight + if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor): + return weight.dequantize().data.detach() + else: + return weight.data.detach() + + def get_orig_bias(self): + if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None: + return self.org_module[0].bias.data.detach() + return None + + # def dora_forward(self, x, *args, **kwargs): + # lora = torch.matmul(self.lora_A, self.lora_B) + # adapted = self.get_orig_weight() + lora + # column_norm = adapted.norm(p=2, dim=0, keepdim=True) + # norm_adapted = adapted / column_norm + # calc_weights = self.magnitude * norm_adapted + # return F.linear(x, calc_weights, self.get_orig_bias()) + + def _get_weight_norm(self, weight, scaled_lora_weight) -> torch.Tensor: + # calculate L2 norm of weight matrix, column-wise + weight = weight + scaled_lora_weight.to(weight.device) + weight_norm = torch.linalg.norm(weight, dim=1) + return weight_norm + + def apply_dora(self, x, scaled_lora_weight): + # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L192 + # lora weight is already scaled + + # magnitude = self.lora_magnitude_vector[active_adapter] + weight = self.get_orig_weight() + weight = weight.to(scaled_lora_weight.device, dtype=scaled_lora_weight.dtype) + weight_norm = self._get_weight_norm(weight, scaled_lora_weight) + # see section 4.3 of DoRA 
class FakeVAE(nn.Module):
    """Pass-through stand-in for a VAE: encode/decode return their input unchanged.

    It has no parameters of its own, so the dtype and device it reports are
    plain attributes that ``to()`` keeps up to date.
    """

    def __init__(self, scaling_factor=1.0):
        super().__init__()
        self._dtype = torch.float32
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = Config()
        self.config.scaling_factor = scaling_factor

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    # mimic to from torch
    def to(self, *args, **kwargs):
        """Record the requested dtype/device, then defer to ``nn.Module.to``.

        Positional arguments are inspected as well as the ``dtype=`` and
        ``device=`` keywords, so ``vae.to(device)``, ``vae.to(dtype)``,
        ``vae.to(device, dtype)`` and ``vae.to(tensor)`` all update the
        reported ``dtype``/``device``.
        """
        for arg in args:
            if isinstance(arg, bool):
                # positional non_blocking flag; bool is an int subclass, so
                # it must not be mistaken for a device index
                continue
            if isinstance(arg, torch.dtype):
                self._dtype = arg
            elif isinstance(arg, (str, int, torch.device)):
                self._device = arg
            elif isinstance(arg, torch.Tensor):
                # the tensor form adopts both the tensor's dtype and device
                self._dtype = arg.dtype
                self._device = arg.device
        # keywords win over positionals; None means "leave unchanged"
        if kwargs.get("dtype") is not None:
            self._dtype = kwargs["dtype"]
        if kwargs.get("device") is not None:
            self._device = kwargs["device"]
        return super().to(*args, **kwargs)

    def enable_xformers_memory_efficient_attention(self):
        pass

    # @apply_forward_hook
    def encode(
            self, x: torch.FloatTensor, return_dict: bool = True
    ) -> AutoencoderKLOutput:
        """Return ``x`` as the "latent": ``(x,)`` or an output whose
        ``latent_dist.sample()`` yields ``x``."""
        h = x

        # moments = self.quant_conv(h)
        # posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (h,)

        class FakeDist:
            # minimal stand-in for a latent distribution: sample() returns
            # the stored tensor
            def __init__(self, x):
                self._sample = x

            def sample(self):
                return self._sample

        return AutoencoderKLOutput(latent_dist=FakeDist(h))

    def _decode(
            self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        dec = z

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # @apply_forward_hook
    def decode(
            self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        """Return ``z`` unchanged, as ``(z,)`` or ``DecoderOutput(sample=z)``."""
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    # The methods below are deliberate no-ops.

    def _set_gradient_checkpointing(self, module, value=False):
        pass

    def enable_tiling(self, use_tiling: bool = True):
        pass

    def disable_tiling(self):
        pass

    def enable_slicing(self):
        pass

    def disable_slicing(self):
        pass

    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
        pass

    def forward(
            self,
            sample: torch.FloatTensor,
            sample_posterior: bool = False,
            return_dict: bool = True,
            generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        """Identity forward pass; ``sample_posterior`` and ``generator`` are unused."""
        dec = sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
class TransformerBlock(nn.Module):
    """Post-norm transformer block: self-attention, cross-attention, then a
    ReLU feed-forward, each followed by a residual add and a LayerNorm.

    Both attention layers are built with ``batch_first=True``.
    """

    def __init__(self, d_model, nhead, dim_feedforward):
        super(TransformerBlock, self).__init__()
        # Sub-modules are created in the same order and under the same names
        # as before, so state-dict keys and seeded initialisation are unchanged.
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
        expand = nn.Linear(d_model, dim_feedforward)
        contract = nn.Linear(dim_feedforward, d_model)
        self.feed_forward = nn.Sequential(expand, nn.ReLU(), contract)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, cross_attn_input):
        """Run ``x`` through the block, attending to ``cross_attn_input`` as
        keys and values in the cross-attention step."""
        # Self-attention (attention weights, element [1], are discarded)
        attended = self.self_attn(x, x, x)[0]
        hidden = self.norm1(x + attended)

        # Cross-attention against the conditioning sequence
        attended = self.cross_attn(hidden, cross_attn_input, cross_attn_input)[0]
        hidden = self.norm2(hidden + attended)

        # Feed-forward
        return self.norm3(hidden + self.feed_forward(hidden))
def up_forward(self, x, *args, **kwargs):
    """Apply a per-sample "up" weight, taken from the shared embedding, to ``x``.

    The embedding for this module is read from the owning module's
    ``img_embeds[self.index]`` and cached on ``self.embed``. Its last
    ``prod(self.up_shape)`` entries per row are reshaped to ``self.up_shape``
    and used as that sample's weight: a ``conv2d`` kernel when the shape is
    4-D (padding 1 for 3-wide kernels, otherwise 0), else a matrix applied
    as ``x @ W.T``. If ``x`` has exactly twice as many samples as the
    embedding, the weights are repeated once so both halves of the batch
    share them. Extra positional/keyword arguments are ignored.
    """
    self.embed = self.instant_lora_module_ref().img_embeds[self.index]
    n_up = math.prod(self.up_shape)
    per_sample_weights = self.embed[:, -n_up:]

    n_samples = x.shape[0]

    # batch carries a second (unconditional) half: reuse the same weights
    if n_samples == per_sample_weights.shape[0] * 2:
        per_sample_weights = torch.cat([per_sample_weights, per_sample_weights], dim=0)

    weight_parts = torch.chunk(per_sample_weights, n_samples, dim=0)
    input_parts = torch.chunk(x, n_samples, dim=0)

    outputs = []
    for idx in range(n_samples):
        kernel = weight_parts[idx]
        sample = input_parts[idx]
        kernel = kernel.view(self.up_shape)
        if len(kernel.shape) == 4:
            # convolutional weight; 3-wide kernels keep the spatial size
            pad = 1 if kernel.shape[-1] == 3 else 0
            sample = nn.functional.conv2d(sample, kernel, padding=pad)
        else:
            # linear weight
            sample = sample @ kernel.T
        outputs.append(sample)
    return torch.cat(outputs, dim=0)
class LoRAFormer(torch.nn.Module):
    """Holds per-sample LoRA weights and reroutes a LoRA network to read them.

    On construction, the ``lora_down`` / ``lora_up`` forwards of every module
    returned by ``sd.network.get_all_modules()`` are replaced with the
    ``down_forward`` / ``up_forward`` of an ``InstantLoRAMidModule``, which
    reads its weights from ``self.img_embeds``.

    NOTE(review): the class looks unfinished. ``forward`` and
    ``get_additional_save_metadata`` read attributes that this ``__init__``
    never assigns (see the notes on those methods), and the transformer
    pieces built here are not used by the visible ``forward``.
    """

    def __init__(
            self,
            num_blocks,
            d_model=1024,
            nhead=16,
            dim_feedforward=4096,
            sd: 'StableDiffusion'=None,
    ):
        """
        Args:
            num_blocks: number of ``TransformerBlock`` layers to create.
            d_model: width passed to the projections and transformer blocks.
            nhead: attention heads per transformer block.
            dim_feedforward: hidden width of each block's feed-forward.
            sd: model wrapper whose ``network`` holds the LoRA modules. It is
                dereferenced immediately (``sd.network``), so despite the
                ``None`` default it is effectively required.
        """
        super(LoRAFormer, self).__init__()
        # self.linear = torch.nn.Linear(2, 1)
        # only a weak reference to the model wrapper is kept
        self.sd_ref = weakref.ref(sd)
        self.dim = sd.network.lora_dim

        # stores the projection vector. Grabbed by modules
        self.img_embeds: List[torch.Tensor] = None

        # disable merging in. It is slower on inference
        self.sd_ref().network.can_merge_in = False

        self.ilora_modules = torch.nn.ModuleList()

        lora_modules = self.sd_ref().network.get_all_modules()

        # running total of weight elements across all LoRA modules
        output_size = 0

        # per-module element counts, in module order; forward() slices by these
        self.embed_lengths = []
        # [lora_name, [down_shape, up_shape]] per module
        self.weight_mapping = []

        for idx, lora_module in enumerate(lora_modules):
            module_dict = lora_module.state_dict()
            down_shape = list(module_dict['lora_down.weight'].shape)
            up_shape = list(module_dict['lora_up.weight'].shape)

            self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])

            # each module's slice holds its down weight and its up weight
            module_size = math.prod(down_shape) + math.prod(up_shape)
            output_size += module_size
            self.embed_lengths.append(module_size)


            # add a new mid module that will take the original forward and add a vector to it
            # this will be used to add the vector to the original forward
            instant_module = InstantLoRAMidModule(
                idx,
                lora_module,
                self,
                up_shape=up_shape,
                down_shape=down_shape
            )

            self.ilora_modules.append(instant_module)

            # replace the LoRA forwards
            lora_module.lora_down.forward = instant_module.down_forward
            lora_module.lora_up.forward = instant_module.up_forward


        self.output_size = output_size

        # NOTE(review): latent, latent_proj, blocks and final_proj are created
        # here but not referenced by the forward() defined below.
        self.latent = nn.Parameter(torch.randn(1, output_size))
        self.latent_proj = nn.Linear(output_size, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, nhead, dim_feedforward)
            for _ in range(num_blocks)
        ])
        self.final_proj = nn.Linear(d_model, output_size)

        self.migrate_weight_mapping()

    def migrate_weight_mapping(self):
        """Currently a no-op: returns immediately; the key-renaming logic
        below is commented out."""
        return
        # # changes the names of the modules to common ones
        # keymap = self.sd_ref().network.get_keymap()
        # save_keymap = {}
        # if keymap is not None:
        #     for ldm_key, diffusers_key in keymap.items():
        #         #  invert them
        #         save_keymap[diffusers_key] = ldm_key
        #
        #     new_keymap = {}
        #     for key, value in self.weight_mapping:
        #         if key in save_keymap:
        #             new_keymap[save_keymap[key]] = value
        #         else:
        #             print(f"Key {key} not found in keymap")
        #             new_keymap[key] = value
        #     self.weight_mapping = new_keymap
        # else:
        #     print("No keymap found. Using default names")
        #     return


    def forward(self, img_embeds):
        """Project ``img_embeds`` and split the result into per-module slices.

        The slices are stored in ``self.img_embeds`` (one per LoRA module, in
        the order of ``self.embed_lengths``); nothing is returned.

        NOTE(review): ``self.resampler`` and ``self.proj_module`` are not
        assigned anywhere in this class's ``__init__``; unless they are
        attached elsewhere, calling this raises ``AttributeError``.
        """
        # expand token rank if only rank 2
        if len(img_embeds.shape) == 2:
            img_embeds = img_embeds.unsqueeze(1)

        # resample the image embeddings
        img_embeds = self.resampler(img_embeds)
        img_embeds = self.proj_module(img_embeds)
        if len(img_embeds.shape) == 3:
            # merge the heads
            img_embeds = img_embeds.mean(dim=1)

        self.img_embeds = []
        # get all the slices
        start = 0
        for length in self.embed_lengths:
            self.img_embeds.append(img_embeds[:, start:start+length])
            start += length


    def get_additional_save_metadata(self) -> Dict[str, Any]:
        """Return extra metadata to store with a checkpoint.

        NOTE(review): ``num_heads``, ``vision_hidden_size``, ``head_dim`` and
        ``vision_tokens`` are not assigned in this class's ``__init__``;
        confirm they are set elsewhere before relying on this method.
        """
        # save the weight mapping
        return {
            "weight_mapping": self.weight_mapping,
            "num_heads": self.num_heads,
            "vision_hidden_size": self.vision_hidden_size,
            "head_dim": self.head_dim,
            "vision_tokens": self.vision_tokens,
            "output_size": self.output_size,
        }
import block as B + +esrgan_safetensors_keys = ['model.0.weight', 'model.0.bias', 'model.1.sub.0.RDB1.conv1.0.weight', + 'model.1.sub.0.RDB1.conv1.0.bias', 'model.1.sub.0.RDB1.conv2.0.weight', + 'model.1.sub.0.RDB1.conv2.0.bias', 'model.1.sub.0.RDB1.conv3.0.weight', + 'model.1.sub.0.RDB1.conv3.0.bias', 'model.1.sub.0.RDB1.conv4.0.weight', + 'model.1.sub.0.RDB1.conv4.0.bias', 'model.1.sub.0.RDB1.conv5.0.weight', + 'model.1.sub.0.RDB1.conv5.0.bias', 'model.1.sub.0.RDB2.conv1.0.weight', + 'model.1.sub.0.RDB2.conv1.0.bias', 'model.1.sub.0.RDB2.conv2.0.weight', + 'model.1.sub.0.RDB2.conv2.0.bias', 'model.1.sub.0.RDB2.conv3.0.weight', + 'model.1.sub.0.RDB2.conv3.0.bias', 'model.1.sub.0.RDB2.conv4.0.weight', + 'model.1.sub.0.RDB2.conv4.0.bias', 'model.1.sub.0.RDB2.conv5.0.weight', + 'model.1.sub.0.RDB2.conv5.0.bias', 'model.1.sub.0.RDB3.conv1.0.weight', + 'model.1.sub.0.RDB3.conv1.0.bias', 'model.1.sub.0.RDB3.conv2.0.weight', + 'model.1.sub.0.RDB3.conv2.0.bias', 'model.1.sub.0.RDB3.conv3.0.weight', + 'model.1.sub.0.RDB3.conv3.0.bias', 'model.1.sub.0.RDB3.conv4.0.weight', + 'model.1.sub.0.RDB3.conv4.0.bias', 'model.1.sub.0.RDB3.conv5.0.weight', + 'model.1.sub.0.RDB3.conv5.0.bias', 'model.1.sub.1.RDB1.conv1.0.weight', + 'model.1.sub.1.RDB1.conv1.0.bias', 'model.1.sub.1.RDB1.conv2.0.weight', + 'model.1.sub.1.RDB1.conv2.0.bias', 'model.1.sub.1.RDB1.conv3.0.weight', + 'model.1.sub.1.RDB1.conv3.0.bias', 'model.1.sub.1.RDB1.conv4.0.weight', + 'model.1.sub.1.RDB1.conv4.0.bias', 'model.1.sub.1.RDB1.conv5.0.weight', + 'model.1.sub.1.RDB1.conv5.0.bias', 'model.1.sub.1.RDB2.conv1.0.weight', + 'model.1.sub.1.RDB2.conv1.0.bias', 'model.1.sub.1.RDB2.conv2.0.weight', + 'model.1.sub.1.RDB2.conv2.0.bias', 'model.1.sub.1.RDB2.conv3.0.weight', + 'model.1.sub.1.RDB2.conv3.0.bias', 'model.1.sub.1.RDB2.conv4.0.weight', + 'model.1.sub.1.RDB2.conv4.0.bias', 'model.1.sub.1.RDB2.conv5.0.weight', + 'model.1.sub.1.RDB2.conv5.0.bias', 'model.1.sub.1.RDB3.conv1.0.weight', + 
'model.1.sub.1.RDB3.conv1.0.bias', 'model.1.sub.1.RDB3.conv2.0.weight', + 'model.1.sub.1.RDB3.conv2.0.bias', 'model.1.sub.1.RDB3.conv3.0.weight', + 'model.1.sub.1.RDB3.conv3.0.bias', 'model.1.sub.1.RDB3.conv4.0.weight', + 'model.1.sub.1.RDB3.conv4.0.bias', 'model.1.sub.1.RDB3.conv5.0.weight', + 'model.1.sub.1.RDB3.conv5.0.bias', 'model.1.sub.2.RDB1.conv1.0.weight', + 'model.1.sub.2.RDB1.conv1.0.bias', 'model.1.sub.2.RDB1.conv2.0.weight', + 'model.1.sub.2.RDB1.conv2.0.bias', 'model.1.sub.2.RDB1.conv3.0.weight', + 'model.1.sub.2.RDB1.conv3.0.bias', 'model.1.sub.2.RDB1.conv4.0.weight', + 'model.1.sub.2.RDB1.conv4.0.bias', 'model.1.sub.2.RDB1.conv5.0.weight', + 'model.1.sub.2.RDB1.conv5.0.bias', 'model.1.sub.2.RDB2.conv1.0.weight', + 'model.1.sub.2.RDB2.conv1.0.bias', 'model.1.sub.2.RDB2.conv2.0.weight', + 'model.1.sub.2.RDB2.conv2.0.bias', 'model.1.sub.2.RDB2.conv3.0.weight', + 'model.1.sub.2.RDB2.conv3.0.bias', 'model.1.sub.2.RDB2.conv4.0.weight', + 'model.1.sub.2.RDB2.conv4.0.bias', 'model.1.sub.2.RDB2.conv5.0.weight', + 'model.1.sub.2.RDB2.conv5.0.bias', 'model.1.sub.2.RDB3.conv1.0.weight', + 'model.1.sub.2.RDB3.conv1.0.bias', 'model.1.sub.2.RDB3.conv2.0.weight', + 'model.1.sub.2.RDB3.conv2.0.bias', 'model.1.sub.2.RDB3.conv3.0.weight', + 'model.1.sub.2.RDB3.conv3.0.bias', 'model.1.sub.2.RDB3.conv4.0.weight', + 'model.1.sub.2.RDB3.conv4.0.bias', 'model.1.sub.2.RDB3.conv5.0.weight', + 'model.1.sub.2.RDB3.conv5.0.bias', 'model.1.sub.3.RDB1.conv1.0.weight', + 'model.1.sub.3.RDB1.conv1.0.bias', 'model.1.sub.3.RDB1.conv2.0.weight', + 'model.1.sub.3.RDB1.conv2.0.bias', 'model.1.sub.3.RDB1.conv3.0.weight', + 'model.1.sub.3.RDB1.conv3.0.bias', 'model.1.sub.3.RDB1.conv4.0.weight', + 'model.1.sub.3.RDB1.conv4.0.bias', 'model.1.sub.3.RDB1.conv5.0.weight', + 'model.1.sub.3.RDB1.conv5.0.bias', 'model.1.sub.3.RDB2.conv1.0.weight', + 'model.1.sub.3.RDB2.conv1.0.bias', 'model.1.sub.3.RDB2.conv2.0.weight', + 'model.1.sub.3.RDB2.conv2.0.bias', 'model.1.sub.3.RDB2.conv3.0.weight', + 
'model.1.sub.3.RDB2.conv3.0.bias', 'model.1.sub.3.RDB2.conv4.0.weight', + 'model.1.sub.3.RDB2.conv4.0.bias', 'model.1.sub.3.RDB2.conv5.0.weight', + 'model.1.sub.3.RDB2.conv5.0.bias', 'model.1.sub.3.RDB3.conv1.0.weight', + 'model.1.sub.3.RDB3.conv1.0.bias', 'model.1.sub.3.RDB3.conv2.0.weight', + 'model.1.sub.3.RDB3.conv2.0.bias', 'model.1.sub.3.RDB3.conv3.0.weight', + 'model.1.sub.3.RDB3.conv3.0.bias', 'model.1.sub.3.RDB3.conv4.0.weight', + 'model.1.sub.3.RDB3.conv4.0.bias', 'model.1.sub.3.RDB3.conv5.0.weight', + 'model.1.sub.3.RDB3.conv5.0.bias', 'model.1.sub.4.RDB1.conv1.0.weight', + 'model.1.sub.4.RDB1.conv1.0.bias', 'model.1.sub.4.RDB1.conv2.0.weight', + 'model.1.sub.4.RDB1.conv2.0.bias', 'model.1.sub.4.RDB1.conv3.0.weight', + 'model.1.sub.4.RDB1.conv3.0.bias', 'model.1.sub.4.RDB1.conv4.0.weight', + 'model.1.sub.4.RDB1.conv4.0.bias', 'model.1.sub.4.RDB1.conv5.0.weight', + 'model.1.sub.4.RDB1.conv5.0.bias', 'model.1.sub.4.RDB2.conv1.0.weight', + 'model.1.sub.4.RDB2.conv1.0.bias', 'model.1.sub.4.RDB2.conv2.0.weight', + 'model.1.sub.4.RDB2.conv2.0.bias', 'model.1.sub.4.RDB2.conv3.0.weight', + 'model.1.sub.4.RDB2.conv3.0.bias', 'model.1.sub.4.RDB2.conv4.0.weight', + 'model.1.sub.4.RDB2.conv4.0.bias', 'model.1.sub.4.RDB2.conv5.0.weight', + 'model.1.sub.4.RDB2.conv5.0.bias', 'model.1.sub.4.RDB3.conv1.0.weight', + 'model.1.sub.4.RDB3.conv1.0.bias', 'model.1.sub.4.RDB3.conv2.0.weight', + 'model.1.sub.4.RDB3.conv2.0.bias', 'model.1.sub.4.RDB3.conv3.0.weight', + 'model.1.sub.4.RDB3.conv3.0.bias', 'model.1.sub.4.RDB3.conv4.0.weight', + 'model.1.sub.4.RDB3.conv4.0.bias', 'model.1.sub.4.RDB3.conv5.0.weight', + 'model.1.sub.4.RDB3.conv5.0.bias', 'model.1.sub.5.RDB1.conv1.0.weight', + 'model.1.sub.5.RDB1.conv1.0.bias', 'model.1.sub.5.RDB1.conv2.0.weight', + 'model.1.sub.5.RDB1.conv2.0.bias', 'model.1.sub.5.RDB1.conv3.0.weight', + 'model.1.sub.5.RDB1.conv3.0.bias', 'model.1.sub.5.RDB1.conv4.0.weight', + 'model.1.sub.5.RDB1.conv4.0.bias', 'model.1.sub.5.RDB1.conv5.0.weight', + 
'model.1.sub.5.RDB1.conv5.0.bias', 'model.1.sub.5.RDB2.conv1.0.weight', + 'model.1.sub.5.RDB2.conv1.0.bias', 'model.1.sub.5.RDB2.conv2.0.weight', + 'model.1.sub.5.RDB2.conv2.0.bias', 'model.1.sub.5.RDB2.conv3.0.weight', + 'model.1.sub.5.RDB2.conv3.0.bias', 'model.1.sub.5.RDB2.conv4.0.weight', + 'model.1.sub.5.RDB2.conv4.0.bias', 'model.1.sub.5.RDB2.conv5.0.weight', + 'model.1.sub.5.RDB2.conv5.0.bias', 'model.1.sub.5.RDB3.conv1.0.weight', + 'model.1.sub.5.RDB3.conv1.0.bias', 'model.1.sub.5.RDB3.conv2.0.weight', + 'model.1.sub.5.RDB3.conv2.0.bias', 'model.1.sub.5.RDB3.conv3.0.weight', + 'model.1.sub.5.RDB3.conv3.0.bias', 'model.1.sub.5.RDB3.conv4.0.weight', + 'model.1.sub.5.RDB3.conv4.0.bias', 'model.1.sub.5.RDB3.conv5.0.weight', + 'model.1.sub.5.RDB3.conv5.0.bias', 'model.1.sub.6.RDB1.conv1.0.weight', + 'model.1.sub.6.RDB1.conv1.0.bias', 'model.1.sub.6.RDB1.conv2.0.weight', + 'model.1.sub.6.RDB1.conv2.0.bias', 'model.1.sub.6.RDB1.conv3.0.weight', + 'model.1.sub.6.RDB1.conv3.0.bias', 'model.1.sub.6.RDB1.conv4.0.weight', + 'model.1.sub.6.RDB1.conv4.0.bias', 'model.1.sub.6.RDB1.conv5.0.weight', + 'model.1.sub.6.RDB1.conv5.0.bias', 'model.1.sub.6.RDB2.conv1.0.weight', + 'model.1.sub.6.RDB2.conv1.0.bias', 'model.1.sub.6.RDB2.conv2.0.weight', + 'model.1.sub.6.RDB2.conv2.0.bias', 'model.1.sub.6.RDB2.conv3.0.weight', + 'model.1.sub.6.RDB2.conv3.0.bias', 'model.1.sub.6.RDB2.conv4.0.weight', + 'model.1.sub.6.RDB2.conv4.0.bias', 'model.1.sub.6.RDB2.conv5.0.weight', + 'model.1.sub.6.RDB2.conv5.0.bias', 'model.1.sub.6.RDB3.conv1.0.weight', + 'model.1.sub.6.RDB3.conv1.0.bias', 'model.1.sub.6.RDB3.conv2.0.weight', + 'model.1.sub.6.RDB3.conv2.0.bias', 'model.1.sub.6.RDB3.conv3.0.weight', + 'model.1.sub.6.RDB3.conv3.0.bias', 'model.1.sub.6.RDB3.conv4.0.weight', + 'model.1.sub.6.RDB3.conv4.0.bias', 'model.1.sub.6.RDB3.conv5.0.weight', + 'model.1.sub.6.RDB3.conv5.0.bias', 'model.1.sub.7.RDB1.conv1.0.weight', + 'model.1.sub.7.RDB1.conv1.0.bias', 'model.1.sub.7.RDB1.conv2.0.weight', + 
'model.1.sub.7.RDB1.conv2.0.bias', 'model.1.sub.7.RDB1.conv3.0.weight', + 'model.1.sub.7.RDB1.conv3.0.bias', 'model.1.sub.7.RDB1.conv4.0.weight', + 'model.1.sub.7.RDB1.conv4.0.bias', 'model.1.sub.7.RDB1.conv5.0.weight', + 'model.1.sub.7.RDB1.conv5.0.bias', 'model.1.sub.7.RDB2.conv1.0.weight', + 'model.1.sub.7.RDB2.conv1.0.bias', 'model.1.sub.7.RDB2.conv2.0.weight', + 'model.1.sub.7.RDB2.conv2.0.bias', 'model.1.sub.7.RDB2.conv3.0.weight', + 'model.1.sub.7.RDB2.conv3.0.bias', 'model.1.sub.7.RDB2.conv4.0.weight', + 'model.1.sub.7.RDB2.conv4.0.bias', 'model.1.sub.7.RDB2.conv5.0.weight', + 'model.1.sub.7.RDB2.conv5.0.bias', 'model.1.sub.7.RDB3.conv1.0.weight', + 'model.1.sub.7.RDB3.conv1.0.bias', 'model.1.sub.7.RDB3.conv2.0.weight', + 'model.1.sub.7.RDB3.conv2.0.bias', 'model.1.sub.7.RDB3.conv3.0.weight', + 'model.1.sub.7.RDB3.conv3.0.bias', 'model.1.sub.7.RDB3.conv4.0.weight', + 'model.1.sub.7.RDB3.conv4.0.bias', 'model.1.sub.7.RDB3.conv5.0.weight', + 'model.1.sub.7.RDB3.conv5.0.bias', 'model.1.sub.8.RDB1.conv1.0.weight', + 'model.1.sub.8.RDB1.conv1.0.bias', 'model.1.sub.8.RDB1.conv2.0.weight', + 'model.1.sub.8.RDB1.conv2.0.bias', 'model.1.sub.8.RDB1.conv3.0.weight', + 'model.1.sub.8.RDB1.conv3.0.bias', 'model.1.sub.8.RDB1.conv4.0.weight', + 'model.1.sub.8.RDB1.conv4.0.bias', 'model.1.sub.8.RDB1.conv5.0.weight', + 'model.1.sub.8.RDB1.conv5.0.bias', 'model.1.sub.8.RDB2.conv1.0.weight', + 'model.1.sub.8.RDB2.conv1.0.bias', 'model.1.sub.8.RDB2.conv2.0.weight', + 'model.1.sub.8.RDB2.conv2.0.bias', 'model.1.sub.8.RDB2.conv3.0.weight', + 'model.1.sub.8.RDB2.conv3.0.bias', 'model.1.sub.8.RDB2.conv4.0.weight', + 'model.1.sub.8.RDB2.conv4.0.bias', 'model.1.sub.8.RDB2.conv5.0.weight', + 'model.1.sub.8.RDB2.conv5.0.bias', 'model.1.sub.8.RDB3.conv1.0.weight', + 'model.1.sub.8.RDB3.conv1.0.bias', 'model.1.sub.8.RDB3.conv2.0.weight', + 'model.1.sub.8.RDB3.conv2.0.bias', 'model.1.sub.8.RDB3.conv3.0.weight', + 'model.1.sub.8.RDB3.conv3.0.bias', 'model.1.sub.8.RDB3.conv4.0.weight', + 
'model.1.sub.8.RDB3.conv4.0.bias', 'model.1.sub.8.RDB3.conv5.0.weight', + 'model.1.sub.8.RDB3.conv5.0.bias', 'model.1.sub.9.RDB1.conv1.0.weight', + 'model.1.sub.9.RDB1.conv1.0.bias', 'model.1.sub.9.RDB1.conv2.0.weight', + 'model.1.sub.9.RDB1.conv2.0.bias', 'model.1.sub.9.RDB1.conv3.0.weight', + 'model.1.sub.9.RDB1.conv3.0.bias', 'model.1.sub.9.RDB1.conv4.0.weight', + 'model.1.sub.9.RDB1.conv4.0.bias', 'model.1.sub.9.RDB1.conv5.0.weight', + 'model.1.sub.9.RDB1.conv5.0.bias', 'model.1.sub.9.RDB2.conv1.0.weight', + 'model.1.sub.9.RDB2.conv1.0.bias', 'model.1.sub.9.RDB2.conv2.0.weight', + 'model.1.sub.9.RDB2.conv2.0.bias', 'model.1.sub.9.RDB2.conv3.0.weight', + 'model.1.sub.9.RDB2.conv3.0.bias', 'model.1.sub.9.RDB2.conv4.0.weight', + 'model.1.sub.9.RDB2.conv4.0.bias', 'model.1.sub.9.RDB2.conv5.0.weight', + 'model.1.sub.9.RDB2.conv5.0.bias', 'model.1.sub.9.RDB3.conv1.0.weight', + 'model.1.sub.9.RDB3.conv1.0.bias', 'model.1.sub.9.RDB3.conv2.0.weight', + 'model.1.sub.9.RDB3.conv2.0.bias', 'model.1.sub.9.RDB3.conv3.0.weight', + 'model.1.sub.9.RDB3.conv3.0.bias', 'model.1.sub.9.RDB3.conv4.0.weight', + 'model.1.sub.9.RDB3.conv4.0.bias', 'model.1.sub.9.RDB3.conv5.0.weight', + 'model.1.sub.9.RDB3.conv5.0.bias', 'model.1.sub.10.RDB1.conv1.0.weight', + 'model.1.sub.10.RDB1.conv1.0.bias', 'model.1.sub.10.RDB1.conv2.0.weight', + 'model.1.sub.10.RDB1.conv2.0.bias', 'model.1.sub.10.RDB1.conv3.0.weight', + 'model.1.sub.10.RDB1.conv3.0.bias', 'model.1.sub.10.RDB1.conv4.0.weight', + 'model.1.sub.10.RDB1.conv4.0.bias', 'model.1.sub.10.RDB1.conv5.0.weight', + 'model.1.sub.10.RDB1.conv5.0.bias', 'model.1.sub.10.RDB2.conv1.0.weight', + 'model.1.sub.10.RDB2.conv1.0.bias', 'model.1.sub.10.RDB2.conv2.0.weight', + 'model.1.sub.10.RDB2.conv2.0.bias', 'model.1.sub.10.RDB2.conv3.0.weight', + 'model.1.sub.10.RDB2.conv3.0.bias', 'model.1.sub.10.RDB2.conv4.0.weight', + 'model.1.sub.10.RDB2.conv4.0.bias', 'model.1.sub.10.RDB2.conv5.0.weight', + 'model.1.sub.10.RDB2.conv5.0.bias', 
'model.1.sub.10.RDB3.conv1.0.weight', + 'model.1.sub.10.RDB3.conv1.0.bias', 'model.1.sub.10.RDB3.conv2.0.weight', + 'model.1.sub.10.RDB3.conv2.0.bias', 'model.1.sub.10.RDB3.conv3.0.weight', + 'model.1.sub.10.RDB3.conv3.0.bias', 'model.1.sub.10.RDB3.conv4.0.weight', + 'model.1.sub.10.RDB3.conv4.0.bias', 'model.1.sub.10.RDB3.conv5.0.weight', + 'model.1.sub.10.RDB3.conv5.0.bias', 'model.1.sub.11.RDB1.conv1.0.weight', + 'model.1.sub.11.RDB1.conv1.0.bias', 'model.1.sub.11.RDB1.conv2.0.weight', + 'model.1.sub.11.RDB1.conv2.0.bias', 'model.1.sub.11.RDB1.conv3.0.weight', + 'model.1.sub.11.RDB1.conv3.0.bias', 'model.1.sub.11.RDB1.conv4.0.weight', + 'model.1.sub.11.RDB1.conv4.0.bias', 'model.1.sub.11.RDB1.conv5.0.weight', + 'model.1.sub.11.RDB1.conv5.0.bias', 'model.1.sub.11.RDB2.conv1.0.weight', + 'model.1.sub.11.RDB2.conv1.0.bias', 'model.1.sub.11.RDB2.conv2.0.weight', + 'model.1.sub.11.RDB2.conv2.0.bias', 'model.1.sub.11.RDB2.conv3.0.weight', + 'model.1.sub.11.RDB2.conv3.0.bias', 'model.1.sub.11.RDB2.conv4.0.weight', + 'model.1.sub.11.RDB2.conv4.0.bias', 'model.1.sub.11.RDB2.conv5.0.weight', + 'model.1.sub.11.RDB2.conv5.0.bias', 'model.1.sub.11.RDB3.conv1.0.weight', + 'model.1.sub.11.RDB3.conv1.0.bias', 'model.1.sub.11.RDB3.conv2.0.weight', + 'model.1.sub.11.RDB3.conv2.0.bias', 'model.1.sub.11.RDB3.conv3.0.weight', + 'model.1.sub.11.RDB3.conv3.0.bias', 'model.1.sub.11.RDB3.conv4.0.weight', + 'model.1.sub.11.RDB3.conv4.0.bias', 'model.1.sub.11.RDB3.conv5.0.weight', + 'model.1.sub.11.RDB3.conv5.0.bias', 'model.1.sub.12.RDB1.conv1.0.weight', + 'model.1.sub.12.RDB1.conv1.0.bias', 'model.1.sub.12.RDB1.conv2.0.weight', + 'model.1.sub.12.RDB1.conv2.0.bias', 'model.1.sub.12.RDB1.conv3.0.weight', + 'model.1.sub.12.RDB1.conv3.0.bias', 'model.1.sub.12.RDB1.conv4.0.weight', + 'model.1.sub.12.RDB1.conv4.0.bias', 'model.1.sub.12.RDB1.conv5.0.weight', + 'model.1.sub.12.RDB1.conv5.0.bias', 'model.1.sub.12.RDB2.conv1.0.weight', + 'model.1.sub.12.RDB2.conv1.0.bias', 
'model.1.sub.12.RDB2.conv2.0.weight', + 'model.1.sub.12.RDB2.conv2.0.bias', 'model.1.sub.12.RDB2.conv3.0.weight', + 'model.1.sub.12.RDB2.conv3.0.bias', 'model.1.sub.12.RDB2.conv4.0.weight', + 'model.1.sub.12.RDB2.conv4.0.bias', 'model.1.sub.12.RDB2.conv5.0.weight', + 'model.1.sub.12.RDB2.conv5.0.bias', 'model.1.sub.12.RDB3.conv1.0.weight', + 'model.1.sub.12.RDB3.conv1.0.bias', 'model.1.sub.12.RDB3.conv2.0.weight', + 'model.1.sub.12.RDB3.conv2.0.bias', 'model.1.sub.12.RDB3.conv3.0.weight', + 'model.1.sub.12.RDB3.conv3.0.bias', 'model.1.sub.12.RDB3.conv4.0.weight', + 'model.1.sub.12.RDB3.conv4.0.bias', 'model.1.sub.12.RDB3.conv5.0.weight', + 'model.1.sub.12.RDB3.conv5.0.bias', 'model.1.sub.13.RDB1.conv1.0.weight', + 'model.1.sub.13.RDB1.conv1.0.bias', 'model.1.sub.13.RDB1.conv2.0.weight', + 'model.1.sub.13.RDB1.conv2.0.bias', 'model.1.sub.13.RDB1.conv3.0.weight', + 'model.1.sub.13.RDB1.conv3.0.bias', 'model.1.sub.13.RDB1.conv4.0.weight', + 'model.1.sub.13.RDB1.conv4.0.bias', 'model.1.sub.13.RDB1.conv5.0.weight', + 'model.1.sub.13.RDB1.conv5.0.bias', 'model.1.sub.13.RDB2.conv1.0.weight', + 'model.1.sub.13.RDB2.conv1.0.bias', 'model.1.sub.13.RDB2.conv2.0.weight', + 'model.1.sub.13.RDB2.conv2.0.bias', 'model.1.sub.13.RDB2.conv3.0.weight', + 'model.1.sub.13.RDB2.conv3.0.bias', 'model.1.sub.13.RDB2.conv4.0.weight', + 'model.1.sub.13.RDB2.conv4.0.bias', 'model.1.sub.13.RDB2.conv5.0.weight', + 'model.1.sub.13.RDB2.conv5.0.bias', 'model.1.sub.13.RDB3.conv1.0.weight', + 'model.1.sub.13.RDB3.conv1.0.bias', 'model.1.sub.13.RDB3.conv2.0.weight', + 'model.1.sub.13.RDB3.conv2.0.bias', 'model.1.sub.13.RDB3.conv3.0.weight', + 'model.1.sub.13.RDB3.conv3.0.bias', 'model.1.sub.13.RDB3.conv4.0.weight', + 'model.1.sub.13.RDB3.conv4.0.bias', 'model.1.sub.13.RDB3.conv5.0.weight', + 'model.1.sub.13.RDB3.conv5.0.bias', 'model.1.sub.14.RDB1.conv1.0.weight', + 'model.1.sub.14.RDB1.conv1.0.bias', 'model.1.sub.14.RDB1.conv2.0.weight', + 'model.1.sub.14.RDB1.conv2.0.bias', 
'model.1.sub.14.RDB1.conv3.0.weight', + 'model.1.sub.14.RDB1.conv3.0.bias', 'model.1.sub.14.RDB1.conv4.0.weight', + 'model.1.sub.14.RDB1.conv4.0.bias', 'model.1.sub.14.RDB1.conv5.0.weight', + 'model.1.sub.14.RDB1.conv5.0.bias', 'model.1.sub.14.RDB2.conv1.0.weight', + 'model.1.sub.14.RDB2.conv1.0.bias', 'model.1.sub.14.RDB2.conv2.0.weight', + 'model.1.sub.14.RDB2.conv2.0.bias', 'model.1.sub.14.RDB2.conv3.0.weight', + 'model.1.sub.14.RDB2.conv3.0.bias', 'model.1.sub.14.RDB2.conv4.0.weight', + 'model.1.sub.14.RDB2.conv4.0.bias', 'model.1.sub.14.RDB2.conv5.0.weight', + 'model.1.sub.14.RDB2.conv5.0.bias', 'model.1.sub.14.RDB3.conv1.0.weight', + 'model.1.sub.14.RDB3.conv1.0.bias', 'model.1.sub.14.RDB3.conv2.0.weight', + 'model.1.sub.14.RDB3.conv2.0.bias', 'model.1.sub.14.RDB3.conv3.0.weight', + 'model.1.sub.14.RDB3.conv3.0.bias', 'model.1.sub.14.RDB3.conv4.0.weight', + 'model.1.sub.14.RDB3.conv4.0.bias', 'model.1.sub.14.RDB3.conv5.0.weight', + 'model.1.sub.14.RDB3.conv5.0.bias', 'model.1.sub.15.RDB1.conv1.0.weight', + 'model.1.sub.15.RDB1.conv1.0.bias', 'model.1.sub.15.RDB1.conv2.0.weight', + 'model.1.sub.15.RDB1.conv2.0.bias', 'model.1.sub.15.RDB1.conv3.0.weight', + 'model.1.sub.15.RDB1.conv3.0.bias', 'model.1.sub.15.RDB1.conv4.0.weight', + 'model.1.sub.15.RDB1.conv4.0.bias', 'model.1.sub.15.RDB1.conv5.0.weight', + 'model.1.sub.15.RDB1.conv5.0.bias', 'model.1.sub.15.RDB2.conv1.0.weight', + 'model.1.sub.15.RDB2.conv1.0.bias', 'model.1.sub.15.RDB2.conv2.0.weight', + 'model.1.sub.15.RDB2.conv2.0.bias', 'model.1.sub.15.RDB2.conv3.0.weight', + 'model.1.sub.15.RDB2.conv3.0.bias', 'model.1.sub.15.RDB2.conv4.0.weight', + 'model.1.sub.15.RDB2.conv4.0.bias', 'model.1.sub.15.RDB2.conv5.0.weight', + 'model.1.sub.15.RDB2.conv5.0.bias', 'model.1.sub.15.RDB3.conv1.0.weight', + 'model.1.sub.15.RDB3.conv1.0.bias', 'model.1.sub.15.RDB3.conv2.0.weight', + 'model.1.sub.15.RDB3.conv2.0.bias', 'model.1.sub.15.RDB3.conv3.0.weight', + 'model.1.sub.15.RDB3.conv3.0.bias', 
'model.1.sub.15.RDB3.conv4.0.weight', + 'model.1.sub.15.RDB3.conv4.0.bias', 'model.1.sub.15.RDB3.conv5.0.weight', + 'model.1.sub.15.RDB3.conv5.0.bias', 'model.1.sub.16.RDB1.conv1.0.weight', + 'model.1.sub.16.RDB1.conv1.0.bias', 'model.1.sub.16.RDB1.conv2.0.weight', + 'model.1.sub.16.RDB1.conv2.0.bias', 'model.1.sub.16.RDB1.conv3.0.weight', + 'model.1.sub.16.RDB1.conv3.0.bias', 'model.1.sub.16.RDB1.conv4.0.weight', + 'model.1.sub.16.RDB1.conv4.0.bias', 'model.1.sub.16.RDB1.conv5.0.weight', + 'model.1.sub.16.RDB1.conv5.0.bias', 'model.1.sub.16.RDB2.conv1.0.weight', + 'model.1.sub.16.RDB2.conv1.0.bias', 'model.1.sub.16.RDB2.conv2.0.weight', + 'model.1.sub.16.RDB2.conv2.0.bias', 'model.1.sub.16.RDB2.conv3.0.weight', + 'model.1.sub.16.RDB2.conv3.0.bias', 'model.1.sub.16.RDB2.conv4.0.weight', + 'model.1.sub.16.RDB2.conv4.0.bias', 'model.1.sub.16.RDB2.conv5.0.weight', + 'model.1.sub.16.RDB2.conv5.0.bias', 'model.1.sub.16.RDB3.conv1.0.weight', + 'model.1.sub.16.RDB3.conv1.0.bias', 'model.1.sub.16.RDB3.conv2.0.weight', + 'model.1.sub.16.RDB3.conv2.0.bias', 'model.1.sub.16.RDB3.conv3.0.weight', + 'model.1.sub.16.RDB3.conv3.0.bias', 'model.1.sub.16.RDB3.conv4.0.weight', + 'model.1.sub.16.RDB3.conv4.0.bias', 'model.1.sub.16.RDB3.conv5.0.weight', + 'model.1.sub.16.RDB3.conv5.0.bias', 'model.1.sub.17.RDB1.conv1.0.weight', + 'model.1.sub.17.RDB1.conv1.0.bias', 'model.1.sub.17.RDB1.conv2.0.weight', + 'model.1.sub.17.RDB1.conv2.0.bias', 'model.1.sub.17.RDB1.conv3.0.weight', + 'model.1.sub.17.RDB1.conv3.0.bias', 'model.1.sub.17.RDB1.conv4.0.weight', + 'model.1.sub.17.RDB1.conv4.0.bias', 'model.1.sub.17.RDB1.conv5.0.weight', + 'model.1.sub.17.RDB1.conv5.0.bias', 'model.1.sub.17.RDB2.conv1.0.weight', + 'model.1.sub.17.RDB2.conv1.0.bias', 'model.1.sub.17.RDB2.conv2.0.weight', + 'model.1.sub.17.RDB2.conv2.0.bias', 'model.1.sub.17.RDB2.conv3.0.weight', + 'model.1.sub.17.RDB2.conv3.0.bias', 'model.1.sub.17.RDB2.conv4.0.weight', + 'model.1.sub.17.RDB2.conv4.0.bias', 
'model.1.sub.17.RDB2.conv5.0.weight', + 'model.1.sub.17.RDB2.conv5.0.bias', 'model.1.sub.17.RDB3.conv1.0.weight', + 'model.1.sub.17.RDB3.conv1.0.bias', 'model.1.sub.17.RDB3.conv2.0.weight', + 'model.1.sub.17.RDB3.conv2.0.bias', 'model.1.sub.17.RDB3.conv3.0.weight', + 'model.1.sub.17.RDB3.conv3.0.bias', 'model.1.sub.17.RDB3.conv4.0.weight', + 'model.1.sub.17.RDB3.conv4.0.bias', 'model.1.sub.17.RDB3.conv5.0.weight', + 'model.1.sub.17.RDB3.conv5.0.bias', 'model.1.sub.18.RDB1.conv1.0.weight', + 'model.1.sub.18.RDB1.conv1.0.bias', 'model.1.sub.18.RDB1.conv2.0.weight', + 'model.1.sub.18.RDB1.conv2.0.bias', 'model.1.sub.18.RDB1.conv3.0.weight', + 'model.1.sub.18.RDB1.conv3.0.bias', 'model.1.sub.18.RDB1.conv4.0.weight', + 'model.1.sub.18.RDB1.conv4.0.bias', 'model.1.sub.18.RDB1.conv5.0.weight', + 'model.1.sub.18.RDB1.conv5.0.bias', 'model.1.sub.18.RDB2.conv1.0.weight', + 'model.1.sub.18.RDB2.conv1.0.bias', 'model.1.sub.18.RDB2.conv2.0.weight', + 'model.1.sub.18.RDB2.conv2.0.bias', 'model.1.sub.18.RDB2.conv3.0.weight', + 'model.1.sub.18.RDB2.conv3.0.bias', 'model.1.sub.18.RDB2.conv4.0.weight', + 'model.1.sub.18.RDB2.conv4.0.bias', 'model.1.sub.18.RDB2.conv5.0.weight', + 'model.1.sub.18.RDB2.conv5.0.bias', 'model.1.sub.18.RDB3.conv1.0.weight', + 'model.1.sub.18.RDB3.conv1.0.bias', 'model.1.sub.18.RDB3.conv2.0.weight', + 'model.1.sub.18.RDB3.conv2.0.bias', 'model.1.sub.18.RDB3.conv3.0.weight', + 'model.1.sub.18.RDB3.conv3.0.bias', 'model.1.sub.18.RDB3.conv4.0.weight', + 'model.1.sub.18.RDB3.conv4.0.bias', 'model.1.sub.18.RDB3.conv5.0.weight', + 'model.1.sub.18.RDB3.conv5.0.bias', 'model.1.sub.19.RDB1.conv1.0.weight', + 'model.1.sub.19.RDB1.conv1.0.bias', 'model.1.sub.19.RDB1.conv2.0.weight', + 'model.1.sub.19.RDB1.conv2.0.bias', 'model.1.sub.19.RDB1.conv3.0.weight', + 'model.1.sub.19.RDB1.conv3.0.bias', 'model.1.sub.19.RDB1.conv4.0.weight', + 'model.1.sub.19.RDB1.conv4.0.bias', 'model.1.sub.19.RDB1.conv5.0.weight', + 'model.1.sub.19.RDB1.conv5.0.bias', 
'model.1.sub.19.RDB2.conv1.0.weight', + 'model.1.sub.19.RDB2.conv1.0.bias', 'model.1.sub.19.RDB2.conv2.0.weight', + 'model.1.sub.19.RDB2.conv2.0.bias', 'model.1.sub.19.RDB2.conv3.0.weight', + 'model.1.sub.19.RDB2.conv3.0.bias', 'model.1.sub.19.RDB2.conv4.0.weight', + 'model.1.sub.19.RDB2.conv4.0.bias', 'model.1.sub.19.RDB2.conv5.0.weight', + 'model.1.sub.19.RDB2.conv5.0.bias', 'model.1.sub.19.RDB3.conv1.0.weight', + 'model.1.sub.19.RDB3.conv1.0.bias', 'model.1.sub.19.RDB3.conv2.0.weight', + 'model.1.sub.19.RDB3.conv2.0.bias', 'model.1.sub.19.RDB3.conv3.0.weight', + 'model.1.sub.19.RDB3.conv3.0.bias', 'model.1.sub.19.RDB3.conv4.0.weight', + 'model.1.sub.19.RDB3.conv4.0.bias', 'model.1.sub.19.RDB3.conv5.0.weight', + 'model.1.sub.19.RDB3.conv5.0.bias', 'model.1.sub.20.RDB1.conv1.0.weight', + 'model.1.sub.20.RDB1.conv1.0.bias', 'model.1.sub.20.RDB1.conv2.0.weight', + 'model.1.sub.20.RDB1.conv2.0.bias', 'model.1.sub.20.RDB1.conv3.0.weight', + 'model.1.sub.20.RDB1.conv3.0.bias', 'model.1.sub.20.RDB1.conv4.0.weight', + 'model.1.sub.20.RDB1.conv4.0.bias', 'model.1.sub.20.RDB1.conv5.0.weight', + 'model.1.sub.20.RDB1.conv5.0.bias', 'model.1.sub.20.RDB2.conv1.0.weight', + 'model.1.sub.20.RDB2.conv1.0.bias', 'model.1.sub.20.RDB2.conv2.0.weight', + 'model.1.sub.20.RDB2.conv2.0.bias', 'model.1.sub.20.RDB2.conv3.0.weight', + 'model.1.sub.20.RDB2.conv3.0.bias', 'model.1.sub.20.RDB2.conv4.0.weight', + 'model.1.sub.20.RDB2.conv4.0.bias', 'model.1.sub.20.RDB2.conv5.0.weight', + 'model.1.sub.20.RDB2.conv5.0.bias', 'model.1.sub.20.RDB3.conv1.0.weight', + 'model.1.sub.20.RDB3.conv1.0.bias', 'model.1.sub.20.RDB3.conv2.0.weight', + 'model.1.sub.20.RDB3.conv2.0.bias', 'model.1.sub.20.RDB3.conv3.0.weight', + 'model.1.sub.20.RDB3.conv3.0.bias', 'model.1.sub.20.RDB3.conv4.0.weight', + 'model.1.sub.20.RDB3.conv4.0.bias', 'model.1.sub.20.RDB3.conv5.0.weight', + 'model.1.sub.20.RDB3.conv5.0.bias', 'model.1.sub.21.RDB1.conv1.0.weight', + 'model.1.sub.21.RDB1.conv1.0.bias', 
'model.1.sub.21.RDB1.conv2.0.weight', + 'model.1.sub.21.RDB1.conv2.0.bias', 'model.1.sub.21.RDB1.conv3.0.weight', + 'model.1.sub.21.RDB1.conv3.0.bias', 'model.1.sub.21.RDB1.conv4.0.weight', + 'model.1.sub.21.RDB1.conv4.0.bias', 'model.1.sub.21.RDB1.conv5.0.weight', + 'model.1.sub.21.RDB1.conv5.0.bias', 'model.1.sub.21.RDB2.conv1.0.weight', + 'model.1.sub.21.RDB2.conv1.0.bias', 'model.1.sub.21.RDB2.conv2.0.weight', + 'model.1.sub.21.RDB2.conv2.0.bias', 'model.1.sub.21.RDB2.conv3.0.weight', + 'model.1.sub.21.RDB2.conv3.0.bias', 'model.1.sub.21.RDB2.conv4.0.weight', + 'model.1.sub.21.RDB2.conv4.0.bias', 'model.1.sub.21.RDB2.conv5.0.weight', + 'model.1.sub.21.RDB2.conv5.0.bias', 'model.1.sub.21.RDB3.conv1.0.weight', + 'model.1.sub.21.RDB3.conv1.0.bias', 'model.1.sub.21.RDB3.conv2.0.weight', + 'model.1.sub.21.RDB3.conv2.0.bias', 'model.1.sub.21.RDB3.conv3.0.weight', + 'model.1.sub.21.RDB3.conv3.0.bias', 'model.1.sub.21.RDB3.conv4.0.weight', + 'model.1.sub.21.RDB3.conv4.0.bias', 'model.1.sub.21.RDB3.conv5.0.weight', + 'model.1.sub.21.RDB3.conv5.0.bias', 'model.1.sub.22.RDB1.conv1.0.weight', + 'model.1.sub.22.RDB1.conv1.0.bias', 'model.1.sub.22.RDB1.conv2.0.weight', + 'model.1.sub.22.RDB1.conv2.0.bias', 'model.1.sub.22.RDB1.conv3.0.weight', + 'model.1.sub.22.RDB1.conv3.0.bias', 'model.1.sub.22.RDB1.conv4.0.weight', + 'model.1.sub.22.RDB1.conv4.0.bias', 'model.1.sub.22.RDB1.conv5.0.weight', + 'model.1.sub.22.RDB1.conv5.0.bias', 'model.1.sub.22.RDB2.conv1.0.weight', + 'model.1.sub.22.RDB2.conv1.0.bias', 'model.1.sub.22.RDB2.conv2.0.weight', + 'model.1.sub.22.RDB2.conv2.0.bias', 'model.1.sub.22.RDB2.conv3.0.weight', + 'model.1.sub.22.RDB2.conv3.0.bias', 'model.1.sub.22.RDB2.conv4.0.weight', + 'model.1.sub.22.RDB2.conv4.0.bias', 'model.1.sub.22.RDB2.conv5.0.weight', + 'model.1.sub.22.RDB2.conv5.0.bias', 'model.1.sub.22.RDB3.conv1.0.weight', + 'model.1.sub.22.RDB3.conv1.0.bias', 'model.1.sub.22.RDB3.conv2.0.weight', + 'model.1.sub.22.RDB3.conv2.0.bias', 
'model.1.sub.22.RDB3.conv3.0.weight', + 'model.1.sub.22.RDB3.conv3.0.bias', 'model.1.sub.22.RDB3.conv4.0.weight', + 'model.1.sub.22.RDB3.conv4.0.bias', 'model.1.sub.22.RDB3.conv5.0.weight', + 'model.1.sub.22.RDB3.conv5.0.bias', 'model.1.sub.23.weight', 'model.1.sub.23.bias', + 'model.3.weight', 'model.3.bias', 'model.6.weight', 'model.6.bias', 'model.8.weight', + 'model.8.bias', 'model.10.weight', 'model.10.bias'] + + +# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py +# Which enhanced stuff that was already here +class RRDBNet(nn.Module): + def __init__( + self, + state_dict, + norm=None, + act: str = "leakyrelu", + upsampler: str = "upconv", + mode: B.ConvMode = "CNA", + ) -> None: + """ + ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks. + By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao, + and Chen Change Loy. + This is old-arch Residual in Residual Dense Block Network and is not + the newest revision that's available at github.com/xinntao/ESRGAN. + This is on purpose, the newest Network has severely limited the + potential use of the Network with no benefits. + This network supports model files from both new and old-arch. + Args: + norm: Normalization layer + act: Activation layer + upsampler: Upsample layer. 
upconv, pixel_shuffle + mode: Convolution mode + """ + super(RRDBNet, self).__init__() + self.model_arch = "ESRGAN" + self.sub_type = "SR" + + self.state = state_dict + self.norm = norm + self.act = act + self.upsampler = upsampler + self.mode = mode + + self.state_map = { + # currently supports old, new, and newer RRDBNet arch models + # ESRGAN, BSRGAN/RealSR, Real-ESRGAN + "model.0.weight": ("conv_first.weight",), + "model.0.bias": ("conv_first.bias",), + "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"), + "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"), + r"model.1.sub.\1.RDB\2.conv\3.0.\4": ( + r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)", + r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)", + ), + } + if "params_ema" in self.state: + self.state = self.state["params_ema"] + # self.model_arch = "RealESRGAN" + self.num_blocks = self.get_num_blocks() + self.plus = any("conv1x1" in k for k in self.state.keys()) + if self.plus: + self.model_arch = "ESRGAN+" + + self.state = self.new_to_old_arch(self.state) + + self.key_arr = list(self.state.keys()) + + self.in_nc: int = self.state[self.key_arr[0]].shape[1] + self.out_nc: int = self.state[self.key_arr[-1]].shape[0] + + self.scale: int = self.get_scale() + self.num_filters: int = self.state[self.key_arr[0]].shape[0] + + c2x2 = False + if self.state["model.0.weight"].shape[-2] == 2: + c2x2 = True + self.scale = round(math.sqrt(self.scale / 4)) + self.model_arch = "ESRGAN-2c2" + + self.supports_fp16 = True + self.supports_bfp16 = True + self.min_size_restriction = None + + # Detect if pixelunshuffle was used (Real-ESRGAN) + if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in ( + self.in_nc / 4, + self.in_nc / 16, + ): + self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc)) + else: + self.shuffle_factor = None + + upsample_block = { + "upconv": B.upconv_block, + "pixel_shuffle": B.pixelshuffle_block, + }.get(self.upsampler) + if upsample_block is 
None: + raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found") + + if self.scale == 3: + upsample_blocks = upsample_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + upscale_factor=3, + act_type=self.act, + c2x2=c2x2, + ) + else: + upsample_blocks = [ + upsample_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + act_type=self.act, + c2x2=c2x2, + ) + for _ in range(int(math.log(self.scale, 2))) + ] + + self.model = B.sequential( + # fea conv + B.conv_block( + in_nc=self.in_nc, + out_nc=self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + c2x2=c2x2, + ), + B.ShortcutBlock( + B.sequential( + # rrdb blocks + *[ + B.RRDB( + nf=self.num_filters, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=self.norm, + act_type=self.act, + mode="CNA", + plus=self.plus, + c2x2=c2x2, + ) + for _ in range(self.num_blocks) + ], + # lr conv + B.conv_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + kernel_size=3, + norm_type=self.norm, + act_type=None, + mode=self.mode, + c2x2=c2x2, + ), + ) + ), + *upsample_blocks, + # hr_conv0 + B.conv_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + kernel_size=3, + norm_type=None, + act_type=self.act, + c2x2=c2x2, + ), + # hr_conv1 + B.conv_block( + in_nc=self.num_filters, + out_nc=self.out_nc, + kernel_size=3, + norm_type=None, + act_type=None, + c2x2=c2x2, + ), + ) + + # Adjust these properties for calculations outside of the model + if self.shuffle_factor: + self.in_nc //= self.shuffle_factor ** 2 + self.scale //= self.shuffle_factor + + self.load_state_dict(self.state, strict=False) + + def new_to_old_arch(self, state): + """Convert a new-arch model state dictionary to an old-arch dictionary.""" + if "params_ema" in state: + state = state["params_ema"] + + if "conv_first.weight" not in state: + # model is already old arch, this is a loose check, but should be sufficient + return state + + # add nb to state keys + for kind in 
("weight", "bias"): + self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[ + f"model.1.sub./NB/.{kind}" + ] + del self.state_map[f"model.1.sub./NB/.{kind}"] + + old_state = OrderedDict() + for old_key, new_keys in self.state_map.items(): + for new_key in new_keys: + if r"\1" in old_key: + for k, v in state.items(): + sub = re.sub(new_key, old_key, k) + if sub != k: + old_state[sub] = v + else: + if new_key in state: + old_state[old_key] = state[new_key] + + # upconv layers + max_upconv = 0 + for key in state.keys(): + match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key) + if match is not None: + _, key_num, key_type = match.groups() + old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key] + max_upconv = max(max_upconv, int(key_num) * 3) + + # final layers + for key in state.keys(): + if key in ("HRconv.weight", "conv_hr.weight"): + old_state[f"model.{max_upconv + 2}.weight"] = state[key] + elif key in ("HRconv.bias", "conv_hr.bias"): + old_state[f"model.{max_upconv + 2}.bias"] = state[key] + elif key in ("conv_last.weight",): + old_state[f"model.{max_upconv + 4}.weight"] = state[key] + elif key in ("conv_last.bias",): + old_state[f"model.{max_upconv + 4}.bias"] = state[key] + + # Sort by first numeric value of each layer + def compare(item1, item2): + parts1 = item1.split(".") + parts2 = item2.split(".") + int1 = int(parts1[1]) + int2 = int(parts2[1]) + return int1 - int2 + + sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare)) + + # Rebuild the output dict in the right order + out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys) + + return out_dict + + def get_scale(self, min_part: int = 6) -> int: + n = 0 + for part in list(self.state): + parts = part.split(".")[1:] + if len(parts) == 2: + part_num = int(parts[0]) + if part_num > min_part and parts[1] == "weight": + n += 1 + return 2 ** n + + def get_num_blocks(self) -> int: + nbs = [] + state_keys = 
self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + ( + r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)", + ) + for state_key in state_keys: + for k in self.state: + m = re.search(state_key, k) + if m: + nbs.append(int(m.group(1))) + if nbs: + break + return max(*nbs) + 1 + + def forward(self, x): + if self.shuffle_factor: + _, _, h, w = x.size() + mod_pad_h = ( + self.shuffle_factor - h % self.shuffle_factor + ) % self.shuffle_factor + mod_pad_w = ( + self.shuffle_factor - w % self.shuffle_factor + ) % self.shuffle_factor + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") + x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor) + x = self.model(x) + return x[:, :, : h * self.scale, : w * self.scale] + return self.model(x) diff --git a/ai-toolkit/toolkit/models/auraflow.py b/ai-toolkit/toolkit/models/auraflow.py new file mode 100644 index 0000000000000000000000000000000000000000..e2539bda489ccc1975f42b9c9a027076f8fdfc74 --- /dev/null +++ b/ai-toolkit/toolkit/models/auraflow.py @@ -0,0 +1,127 @@ +import math +from functools import partial + +from torch import nn +import torch + + +class AuraFlowPatchEmbed(nn.Module): + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + pos_embed_max_size=None, + ): + super().__init__() + + self.num_patches = (height // patch_size) * (width // patch_size) + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim) + self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1) + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + + def forward(self, latent): + batch_size, num_channels, height, width = latent.size() + latent = latent.view( + batch_size, + num_channels, + height // self.patch_size, + self.patch_size, + width // self.patch_size, + self.patch_size, + ) + latent = 
latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + latent = self.proj(latent) + try: + return latent + self.pos_embed + except RuntimeError: + raise RuntimeError( + f"Positional embeddings are too small for the number of patches. " + f"Please increase `pos_embed_max_size` to at least {self.num_patches}." + ) + + +# comfy +# def apply_pos_embeds(self, x, h, w): +# h = (h + 1) // self.patch_size +# w = (w + 1) // self.patch_size +# max_dim = max(h, w) +# +# cur_dim = self.h_max +# pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype) +# +# if max_dim > cur_dim: +# pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, +# -1) +# cur_dim = max_dim +# +# from_h = (cur_dim - h) // 2 +# from_w = (cur_dim - w) // 2 +# pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w] +# return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) + + # def patchify(self, x): + # B, C, H, W = x.size() + # pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + # pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + # + # x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') + # x = x.view( + # B, + # C, + # (H + 1) // self.patch_size, + # self.patch_size, + # (W + 1) // self.patch_size, + # self.patch_size, + # ) + # x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + # return x + +def patch_auraflow_pos_embed(pos_embed): + # we need to hijack the forward and replace with a custom one. Self is the model + def new_forward(self, latent): + batch_size, num_channels, height, width = latent.size() + + # add padding to the latent to make it match pos_embed + latent_size = height * width * num_channels / 16 # todo check where 16 comes from? 
+ pos_embed_size = self.pos_embed.shape[1] + if latent_size < pos_embed_size: + total_padding = int(pos_embed_size - math.floor(latent_size)) + total_padding = total_padding // 16 + pad_height = total_padding // 2 + pad_width = total_padding - pad_height + # mirror padding on the right side + padding = (0, pad_width, 0, pad_height) + latent = torch.nn.functional.pad(latent, padding, mode='reflect') + elif latent_size > pos_embed_size: + amount_to_remove = latent_size - pos_embed_size + latent = latent[:, :, :-amount_to_remove] + + batch_size, num_channels, height, width = latent.size() + + latent = latent.view( + batch_size, + num_channels, + height // self.patch_size, + self.patch_size, + width // self.patch_size, + self.patch_size, + ) + latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + latent = self.proj(latent) + try: + return latent + self.pos_embed + except RuntimeError: + raise RuntimeError( + f"Positional embeddings are too small for the number of patches. " + f"Please increase `pos_embed_max_size` to at least {self.num_patches}." 
+ ) + + pos_embed.forward = partial(new_forward, pos_embed) diff --git a/ai-toolkit/toolkit/models/autoencoder_tiny_with_pooled_exits.py b/ai-toolkit/toolkit/models/autoencoder_tiny_with_pooled_exits.py new file mode 100644 index 0000000000000000000000000000000000000000..5771955de1223193fc8428a60609691599d3b262 --- /dev/null +++ b/ai-toolkit/toolkit/models/autoencoder_tiny_with_pooled_exits.py @@ -0,0 +1,187 @@ +from typing import Optional, Tuple, Union +from diffusers import AutoencoderTiny +from diffusers.models.autoencoders.vae import ( + EncoderTiny, + get_activation, + AutoencoderTinyBlock, + DecoderOutput +) +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.configuration_utils import register_to_config +import torch +import torch.nn as nn + +class DecoderTinyWithPooledExits(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], + upsampling_scaling_factor: int, + act_fn: str, + upsample_fn: str, + ): + super().__init__() + layers = [] + self.ordered_layers = [] + l = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1) + self.ordered_layers.append(l) + layers.append(l) + l = get_activation(act_fn) + self.ordered_layers.append(l) + layers.append(l) + + pooled_exits = [] + + + for i, num_block in enumerate(num_blocks): + is_final_block = i == (len(num_blocks) - 1) + num_channels = block_out_channels[i] + + for _ in range(num_block): + l = AutoencoderTinyBlock(num_channels, num_channels, act_fn) + layers.append(l) + self.ordered_layers.append(l) + + if not is_final_block: + l = nn.Upsample( + scale_factor=upsampling_scaling_factor, mode=upsample_fn + ) + layers.append(l) + self.ordered_layers.append(l) + + conv_out_channel = num_channels if not is_final_block else out_channels + l = nn.Conv2d( + num_channels, + conv_out_channel, + kernel_size=3, + padding=1, + bias=is_final_block, + ) + layers.append(l) + 
self.ordered_layers.append(l) + + if not is_final_block: + p = nn.Conv2d( + conv_out_channel, + out_channels=3, + kernel_size=3, + padding=1, + bias=True, + ) + p._is_pooled_exit = True + pooled_exits.append(p) + self.ordered_layers.append(p) + + self.layers = nn.ModuleList(layers) + self.pooled_exits = nn.ModuleList(pooled_exits) + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor, pooled_outputs=False) -> torch.Tensor: + r"""The forward method of the `DecoderTiny` class.""" + # Clamp. + x = torch.tanh(x / 3) * 3 + + pooled_output_list = [] + + for layer in self.ordered_layers: + # see if is pooled exit + try: + if hasattr(layer, '_is_pooled_exit') and layer._is_pooled_exit: + if pooled_outputs: + pooled_output = layer(x) + pooled_output_list.append(pooled_output) + else: + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func(layer, x) + else: + x = layer(x) + except RuntimeError as e: + raise e + + # scale image from [0, 1] to [-1, 1] to match diffusers convention + x = x.mul(2).sub(1) + + if pooled_outputs: + return x, pooled_output_list + return x + + +class AutoencoderTinyWithPooledExits(AutoencoderTiny): + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), + decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), + act_fn: str = "relu", + upsample_fn: str = "nearest", + latent_channels: int = 4, + upsampling_scaling_factor: int = 2, + num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3), + num_decoder_blocks: Tuple[int, ...] 
= (3, 3, 3, 1), + latent_magnitude: int = 3, + latent_shift: float = 0.5, + force_upcast: bool = False, + scaling_factor: float = 1.0, + shift_factor: float = 0.0, + ): + super(AutoencoderTiny, self).__init__() + + if len(encoder_block_out_channels) != len(num_encoder_blocks): + raise ValueError( + "`encoder_block_out_channels` should have the same length as `num_encoder_blocks`." + ) + if len(decoder_block_out_channels) != len(num_decoder_blocks): + raise ValueError( + "`decoder_block_out_channels` should have the same length as `num_decoder_blocks`." + ) + + self.encoder = EncoderTiny( + in_channels=in_channels, + out_channels=latent_channels, + num_blocks=num_encoder_blocks, + block_out_channels=encoder_block_out_channels, + act_fn=act_fn, + ) + + self.decoder = DecoderTinyWithPooledExits( + in_channels=latent_channels, + out_channels=out_channels, + num_blocks=num_decoder_blocks, + block_out_channels=decoder_block_out_channels, + upsampling_scaling_factor=upsampling_scaling_factor, + act_fn=act_fn, + upsample_fn=upsample_fn, + ) + + self.latent_magnitude = latent_magnitude + self.latent_shift = latent_shift + self.scaling_factor = scaling_factor + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.spatial_scale_factor = 2**out_channels + self.tile_overlap_factor = 0.125 + self.tile_sample_min_size = 512 + self.tile_latent_min_size = ( + self.tile_sample_min_size // self.spatial_scale_factor + ) + + self.register_to_config(block_out_channels=decoder_block_out_channels) + self.register_to_config(force_upcast=False) + + @apply_forward_hook + def decode_with_pooled_exits( + self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = False + ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: + output, pooled_outputs = self.decoder(x, pooled_outputs=True) + + if not return_dict: + return (output, pooled_outputs) + + return DecoderOutput(sample=output) diff --git 
a/ai-toolkit/toolkit/models/base_model.py b/ai-toolkit/toolkit/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4caf3ae7a3f3c72bf4ac0cd8f9e1ad3966b7f341 --- /dev/null +++ b/ai-toolkit/toolkit/models/base_model.py @@ -0,0 +1,1614 @@ +import copy +import gc +import inspect +import json +import random +import shutil +import typing +from typing import Optional, Union, List, Literal +import os +from collections import OrderedDict +import copy +import yaml +from PIL import Image +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from torch.nn import Parameter +from tqdm import tqdm +from torchvision.transforms import Resize, transforms + +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.ip_adapter import IPAdapter +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch +from toolkit.models.decorator import Decorator +from toolkit.paths import KEYMAPS_ROOT +from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.sd_device_states_presets import empty_preset +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +import torch +from toolkit.pipelines import CustomStableDiffusionXLPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ + LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel +import diffusers +from diffusers import \ + AutoencoderKL, \ + UNet2DConditionModel +from diffusers import PixArtAlphaPipeline +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection +from torchvision.transforms import functional as TF + +from toolkit.accelerator import get_accelerator, unwrap_model +from typing import TYPE_CHECKING +from toolkit.print import print_acc +from 
toolkit.basic import flush + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +# tell it to shut up +diffusers.logging.set_verbosity(diffusers.logging.ERROR) + +SD_PREFIX_VAE = "vae" +SD_PREFIX_UNET = "unet" +SD_PREFIX_REFINER_UNET = "refiner_unet" +SD_PREFIX_TEXT_ENCODER = "te" + +SD_PREFIX_TEXT_ENCODER1 = "te0" +SD_PREFIX_TEXT_ENCODER2 = "te1" + +# prefixed diffusers keys +DO_NOT_TRAIN_WEIGHTS = [ + "unet_time_embedding.linear_1.bias", + "unet_time_embedding.linear_1.weight", + "unet_time_embedding.linear_2.bias", + "unet_time_embedding.linear_2.weight", + "refiner_unet_time_embedding.linear_1.bias", + "refiner_unet_time_embedding.linear_1.weight", + "refiner_unet_time_embedding.linear_2.bias", + "refiner_unet_time_embedding.linear_2.weight", +] + +DeviceStatePreset = Literal['cache_latents', 'generate'] + + +class BlankNetwork: + + def __init__(self): + self.multiplier = 1.0 + self.is_active = True + self.is_merged_in = False + self.can_merge_in = False + + def __enter__(self): + self.is_active = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.is_active = False + + def train(self): + pass + + +UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 +# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 + + +class BaseModel: + # override these in child classes + arch = None + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='fp16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + self.accelerator = get_accelerator() + self.custom_pipeline = custom_pipeline + self.device = device + self.dtype = dtype + self.torch_dtype = get_torch_dtype(dtype) + self.device_torch = torch.device(device) + + self.vae_device_torch = torch.device(device) + self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype) + + self.te_device_torch = torch.device(device) + self.te_torch_dtype = 
get_torch_dtype(model_config.te_dtype) + + self.model_config = model_config + self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" + + self.device_state = None + + self.pipeline: Union[None, 'StableDiffusionPipeline', + 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] + self.vae: Union[None, 'AutoencoderKL'] + self.model: Union[None, 'Transformer2DModel', 'UNet2DConditionModel'] + self.text_encoder: Union[None, 'CLIPTextModel', + List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] + self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] + self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler + + self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None + self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None + + # sdxl stuff + self.logit_scale = None + self.ckppt_info = None + self.is_loaded = False + + # to hold network if there is one + self.network = None + self.adapter: Union['ControlNetModel', 'T2IAdapter', + 'IPAdapter', 'ReferenceAdapter', None] = None + self.decorator: Union[Decorator, None] = None + self.arch: ModelArch = model_config.arch + + self.use_text_encoder_1 = model_config.use_text_encoder_1 + self.use_text_encoder_2 = model_config.use_text_encoder_2 + + self.config_file = None + + self.is_flow_matching = False + + self.quantize_device = self.device_torch + self.low_vram = self.model_config.low_vram + + # merge in and preview active with -1 weight + self.invert_assistant_lora = False + self._after_sample_img_hooks = [] + self._status_update_hooks = [] + self.is_transformer = False + + self.sample_prompts_cache = None + + self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None + self.is_multistage = False + # a list of multistage boundaries starting with train step 1000 to first idx + self.multistage_boundaries: List[float] = [0.0] + # a list of trainable multistage boundaries + self.trainable_multistage_boundaries: List[int] = [0] + + # set 
true for models that encode control image into text embeddings + self.encode_control_in_text_embeddings = False + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = False + # do not resize control images + self.use_raw_control_images = False + # defines if the model supports model paths. Only some will + self.supports_model_paths = False + + # use new lokr format (default false for old models for backwards compatibility) + self.use_old_lokr_format = True + + # when padding to make batch size work, which side padding to use, right or left + # some llms need left side padding, others need right side + self.te_padding_side = "right" + + # can be used on models to invalidate cache if things change. + self.latent_space_version = None + + # if a mask is passed, do the loss with the mask. May be set false for models that use a mask for other reasons. + self.do_masked_loss = True + + # properties for old arch for backwards compatibility + @property + def unet(self): + return self.model + + # set unet to model + @unet.setter + def unet(self, value): + self.model = value + + @property + def transformer(self): + return self.model + + @transformer.setter + def transformer(self, value): + self.model = value + + @property + def unet_unwrapped(self): + return unwrap_model(self.model) + + @property + def model_unwrapped(self): + return unwrap_model(self.model) + + @property + def is_xl(self): + return self.arch == 'sdxl' + + @property + def is_v2(self): + return self.arch == 'sd2' + + @property + def is_ssd(self): + return self.arch == 'ssd' + + @property + def is_v3(self): + return self.arch == 'sd3' + + @property + def is_vega(self): + return self.arch == 'vega' + + @property + def is_pixart(self): + return self.arch == 'pixart' + + @property + def is_auraflow(self): + return self.arch == 'auraflow' + + @property + def is_flux(self): + return self.arch == 'flux' + + @property + def is_lumina2(self): + return self.arch == 
'lumina2' + + @property + def text_embedding_space_version(self): + return self.arch + + def get_bucket_divisibility(self): + if self.vae is None: + return 8 + try: + divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1) + except: + # if we have a custom vae, it might not have this + divisibility = 8 + + # flux packs this again, + if self.is_flux: + divisibility = divisibility * 2 + return divisibility + + # these must be implemented in child classes + def load_model(self): + # override this in child classes + raise NotImplementedError( + "load_model must be implemented in child classes") + + def get_generation_pipeline(self): + # override this in child classes + raise NotImplementedError( + "get_generation_pipeline must be implemented in child classes") + + def generate_single_image( + self, + pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # override this in child classes + raise NotImplementedError( + "generate_single_image must be implemented in child classes") + + def get_noise_prediction( + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + raise NotImplementedError( + "get_noise_prediction must be implemented in child classes") + + def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds: + raise NotImplementedError( + "get_prompt_embeds must be implemented in child classes") + + def get_model_has_grad(self): + raise NotImplementedError( + "get_model_has_grad must be implemented in child classes") + + def get_te_has_grad(self): + raise NotImplementedError( + "get_te_has_grad must be implemented in child classes") + + def save_model(self, output_path, meta, save_dtype): + # todo handle dtype without overloading anything (vram, cpu, etc) + unwrap_model(self.pipeline).save_pretrained( + save_directory=output_path, + 
safe_serialization=True, + ) + # save out meta config + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + # end must be implemented in child classes + + def te_train(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.train() + elif self.text_encoder is not None: + self.text_encoder.train() + + def te_eval(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.eval() + elif self.text_encoder is not None: + self.text_encoder.eval() + + def _after_sample_image(self, img_num, total_imgs): + # process all hooks + for hook in self._after_sample_img_hooks: + hook(img_num, total_imgs) + + def add_after_sample_image_hook(self, func): + self._after_sample_img_hooks.append(func) + + def _status_update(self, status: str): + for hook in self._status_update_hooks: + hook(status) + + def print_and_status_update(self, status: str): + print_acc(status) + self._status_update(status) + + def add_status_update_hook(self, func): + self._status_update_hooks.append(func) + + @torch.no_grad() + def generate_images( + self, + image_configs: List[GenerateImageConfig], + sampler=None, + pipeline: Union[None, StableDiffusionPipeline, + StableDiffusionXLPipeline] = None, + ): + network = self.network + merge_multiplier = 1.0 + flush() + # if using assistant, unfuse it + if self.model_config.assistant_lora_path is not None: + print_acc("Unloading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to( + self.device_torch, self.torch_dtype) + else: + self.assistant_lora.is_active = False + + if self.model_config.inference_lora_path is not None: + print_acc("Loading inference lora") + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to(self.device_torch, self.torch_dtype) + + if network is not None: + network = 
unwrap_model(self.network) + network.eval() + # check if we have the same network weight for all samples. If we do, we can merge in th + # the network to drastically speed up inference + unique_network_weights = set( + [x.network_multiplier for x in image_configs]) + if len(unique_network_weights) == 1 and network.can_merge_in: + can_merge_in = True + merge_multiplier = unique_network_weights.pop() + network.merge_in(merge_weight=merge_multiplier) + else: + network = BlankNetwork() + + self.save_device_state() + self.set_device_state_preset('generate') + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + if pipeline is None: + pipeline = self.get_generation_pipeline() + try: + pipeline.set_progress_bar_config(disable=True) + except: + pass + + start_multiplier = 1.0 + if network is not None: + start_multiplier = network.multiplier + + # pipeline.to(self.device_torch) + + with network: + with torch.no_grad(): + if network is not None: + assert network.is_active + + for i in tqdm(range(len(image_configs)), desc=f"Generating Samples", leave=False): + gen_config = image_configs[i] + + extra = {} + validation_image = None + if self.adapter is not None and gen_config.adapter_image_path is not None: + validation_image = Image.open(gen_config.adapter_image_path) + if ".inpaint." not in gen_config.adapter_image_path: + validation_image = validation_image.convert("RGB") + else: + # make sure it has an alpha + if validation_image.mode != "RGBA": + raise ValueError("Inpainting images must have an alpha channel") + if isinstance(self.adapter, T2IAdapter): + # not sure why this is double?? 
+ validation_image = validation_image.resize( + (gen_config.width * 2, gen_config.height * 2)) + extra['image'] = validation_image + extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, ControlNetModel): + validation_image = validation_image.resize( + (gen_config.width, gen_config.height)) + extra['image'] = validation_image + extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None: + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['control_image'] = validation_image + extra['control_image_idx'] = gen_config.ctrl_idx + if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + if isinstance(self.adapter, CustomAdapter): + # todo allow loading multiple + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + self.adapter.num_images = 1 + if isinstance(self.adapter, ReferenceAdapter): + # need -1 to 1 + validation_image = transforms.ToTensor()(validation_image) + validation_image = validation_image * 2.0 - 1.0 + validation_image = validation_image.unsqueeze(0) + self.adapter.set_reference_images(validation_image) + + if network is not None: + network.multiplier = gen_config.network_multiplier + torch.manual_seed(gen_config.seed) + torch.cuda.manual_seed(gen_config.seed) + + generator = torch.manual_seed(gen_config.seed) + + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ + and gen_config.adapter_image_path is not None: + # run through the adapter to saturate the embeds + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + self.adapter(conditional_clip_embeds) + + if self.adapter is not None 
and isinstance(self.adapter, CustomAdapter): + # handle condition the prompts + gen_config.prompt = self.adapter.condition_prompt( + gen_config.prompt, + is_unconditional=False, + ) + gen_config.prompt_2 = gen_config.prompt + gen_config.negative_prompt = self.adapter.condition_prompt( + gen_config.negative_prompt, + is_unconditional=True, + ) + gen_config.negative_prompt_2 = gen_config.negative_prompt + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: + self.adapter.trigger_pre_te( + tensors_0_1=validation_image, + is_training=False, + has_been_preprocessed=False, + quad_count=4 + ) + + if self.sample_prompts_cache is not None: + conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) + unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) + else: + ctrl_img = None + has_control_images = False + if gen_config.ctrl_img is not None or gen_config.ctrl_img_1 is not None or gen_config.ctrl_img_2 is not None or gen_config.ctrl_img_3 is not None: + has_control_images = True + # load the control image if out model uses it in text encoding + if has_control_images and self.encode_control_in_text_embeddings: + ctrl_img_list = [] + + if gen_config.ctrl_img is not None: + ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img) + + if gen_config.ctrl_img_1 is not None: + ctrl_img_1 = Image.open(gen_config.ctrl_img_1).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_1 = ( + TF.to_tensor(ctrl_img_1) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_1) + if gen_config.ctrl_img_2 is not None: + ctrl_img_2 = Image.open(gen_config.ctrl_img_2).convert("RGB") + # convert to 0 to 1 tensor + 
ctrl_img_2 = ( + TF.to_tensor(ctrl_img_2) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_2) + if gen_config.ctrl_img_3 is not None: + ctrl_img_3 = Image.open(gen_config.ctrl_img_3).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_3 = ( + TF.to_tensor(ctrl_img_3) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_3) + + if self.has_multiple_control_images: + ctrl_img = ctrl_img_list + else: + ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt( + gen_config.prompt, + gen_config.prompt_2, + force_all=True, + control_images=ctrl_img + ) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, + gen_config.negative_prompt_2, + force_all=True, + control_images=ctrl_img + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # allow any manipulations to take place to embeddings + gen_config.post_process_embeddings( + conditional_embeds, + unconditional_embeds, + ) + + if self.decorator is not None: + # apply the decorator to the embeddings + conditional_embeds.text_embeds = self.decorator( + conditional_embeds.text_embeds) + unconditional_embeds.text_embeds = self.decorator( + unconditional_embeds.text_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ + and gen_config.adapter_image_path is not None: + # apply the image projection + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, + True) + conditional_embeds = 
self.adapter( + conditional_embeds, conditional_clip_embeds, is_unconditional=False) + unconditional_embeds = self.adapter( + unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=conditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_generating_samples=True, + ) + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=unconditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=True, + is_generating_samples=True, + ) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( + gen_config.extra_values) > 0: + extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, + dtype=self.torch_dtype) + # apply extra values to the embeddings + self.adapter.add_extra_values( + extra_values, is_unconditional=False) + self.adapter.add_extra_values(torch.zeros_like( + extra_values), is_unconditional=True) + pass # todo remove, for debugging + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # if we have a refiner loaded, set the denoising end at the refiner start + extra['denoising_end'] = gen_config.refiner_start_at + extra['output_type'] = 'latent' + if not self.is_xl: + raise ValueError( + "Refiner is only supported for XL models") + + conditional_embeds = conditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + unconditional_embeds = unconditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + + img = self.generate_single_image( + pipeline, + gen_config, + conditional_embeds, + unconditional_embeds, + generator, + extra, + ) + + gen_config.save_image(img, i) + gen_config.log_image(img, i) + self._after_sample_image(i, len(image_configs)) + flush() + + if self.adapter is 
not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() + + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() + + # restore training state + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + self.restore_device_state() + if network is not None: + network.train() + network.multiplier = start_multiplier + + self.unet.to(self.device_torch, dtype=self.torch_dtype) + if network.is_merged_in: + network.merge_out(merge_multiplier) + # self.tokenizer.to(original_device_dict['tokenizer']) + + # refuse loras + if self.model_config.assistant_lora_path is not None: + print_acc("Loading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + else: + self.assistant_lora.is_active = True + + if self.model_config.inference_lora_path is not None: + print_acc("Unloading inference lora") + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + flush() + + def get_latent_noise( + self, + height=None, + width=None, + pixel_height=None, + pixel_width=None, + batch_size=1, + noise_offset=0.0, + ): + VAE_SCALE_FACTOR = 2 ** ( + len(self.vae.config['block_out_channels']) - 1) + if height is None and pixel_height is None: + raise ValueError("height or pixel_height must be specified") + if width is None and pixel_width is None: + raise ValueError("width or pixel_width must be specified") + if height is None: + height = pixel_height // VAE_SCALE_FACTOR + if width is None: + width = pixel_width // VAE_SCALE_FACTOR + + num_channels = self.unet_unwrapped.config['in_channels'] + if self.is_flux: + # has 64 channels in for some reason + num_channels = 16 + noise = torch.randn( + ( + batch_size, + num_channels, + height, + width, + ), + device=self.unet.device, + ) + noise = 
apply_noise_offset(noise, noise_offset) + return noise + + def get_latent_noise_from_latents( + self, + latents: torch.Tensor, + noise_offset=0.0 + ): + noise = torch.randn_like(latents) + noise = apply_noise_offset(noise, noise_offset) + return noise + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + **kwargs, + ) -> torch.FloatTensor: + original_samples_chunks = torch.chunk( + original_samples, original_samples.shape[0], dim=0) + noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + timesteps_chunks = [timesteps_chunks[0]] * \ + len(original_samples_chunks) + + noisy_latents_chunks = [] + + for idx in range(original_samples.shape[0]): + noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], + timesteps_chunks[idx]) + noisy_latents_chunks.append(noisy_latents) + + noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + return noisy_latents + + def predict_noise( + self, + latents: torch.Tensor, + text_embeddings: Union[PromptEmbeds, None] = None, + timestep: Union[int, torch.Tensor] = 1, + guidance_scale=7.5, + guidance_rescale=0, + add_time_ids=None, + conditional_embeddings: Union[PromptEmbeds, None] = None, + unconditional_embeddings: Union[PromptEmbeds, None] = None, + is_input_scaled=False, + detach_unconditional=False, + rescale_cfg=None, + return_conditional_pred=False, + guidance_embedding_scale=1.0, + bypass_guidance_embedding=False, + batch: Union[None, 'DataLoaderBatchDTO'] = None, + **kwargs, + ): + conditional_pred = None + # get the embeddings + if text_embeddings is None and conditional_embeddings is None: + raise ValueError( + "Either text_embeddings or conditional_embeddings must be specified") + if text_embeddings is None and unconditional_embeddings is not None: + 
text_embeddings = concat_prompt_embeds([ + unconditional_embeddings, # negative embedding + conditional_embeddings, # positive embedding + ]) + elif text_embeddings is None and conditional_embeddings is not None: + # not doing cfg + text_embeddings = conditional_embeddings + + # CFG is comparing neg and positive, if we have concatenated embeddings + # then we are doing it, otherwise we are not and takes half the time. + do_classifier_free_guidance = True + + if isinstance(text_embeddings.text_embeds, list): + if len(text_embeddings.text_embeds[0].shape) == 2: + # handle list of embeddings + te_batch_size = len(text_embeddings.text_embeds) + else: + te_batch_size = text_embeddings.text_embeds[0].shape[0] + else: + te_batch_size = text_embeddings.text_embeds.shape[0] + if latents.shape[0] == te_batch_size: + do_classifier_free_guidance = False + elif latents.shape[0] * 2 != te_batch_size: + raise ValueError( + "Batch size of latents must be the same or half the batch size of text embeddings") + latents = latents.to(self.device_torch) + text_embeddings = text_embeddings.to(self.device_torch) + timestep = timestep.to(self.device_torch) + + # if timestep is zero dim, unsqueeze it + if len(timestep.shape) == 0: + timestep = timestep.unsqueeze(0) + + # if we only have 1 timestep, we can just use the same timestep for all + if timestep.shape[0] == 1 and latents.shape[0] > 1: + # check if it is rank 1 or 2 + if len(timestep.shape) == 1: + timestep = timestep.repeat(latents.shape[0]) + else: + timestep = timestep.repeat(latents.shape[0], 0) + + # handle t2i adapters + if 'down_intrablock_additional_residuals' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([ + item] * 2, dim=0) + + # handle 
controlnet + if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_block_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_block_additional_residuals'][idx] = torch.cat([ + item] * 2, dim=0) + for idx, item in enumerate(kwargs['mid_block_additional_residual']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['mid_block_additional_residual'][idx] = torch.cat( + [item] * 2, dim=0) + + def scale_model_input(model_input, timestep_tensor): + if is_input_scaled: + return model_input + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + timestep_chunks = torch.chunk( + timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + # unsqueeze if timestep is zero dim + for idx in range(model_input.shape[0]): + # if scheduler has step_index + if hasattr(self.noise_scheduler, '_step_index'): + self.noise_scheduler._step_index = None + out_chunks.append( + self.noise_scheduler.scale_model_input( + mi_chunks[idx], timestep_chunks[idx]) + ) + return torch.cat(out_chunks, dim=0) + + with torch.no_grad(): + if do_classifier_free_guidance: + # if we are doing classifier free guidance, need to double up + latent_model_input = torch.cat([latents] * 2, dim=0) + timestep = torch.cat([timestep] * 2) + else: + latent_model_input = latents + + latent_model_input = scale_model_input( + latent_model_input, timestep) + + # check if we need to concat timesteps + if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: + ts_bs = timestep.shape[0] + if ts_bs != latent_model_input.shape[0]: + if ts_bs == 1: + timestep = torch.cat( + [timestep] * latent_model_input.shape[0]) + elif ts_bs * 2 == latent_model_input.shape[0]: + timestep = torch.cat([timestep] * 2, dim=0) 
+ else: + raise ValueError( + f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") + + # predict the noise residual + if self.unet.device != self.device_torch: + try: + self.unet.to(self.device_torch) + except Exception as e: + pass + if self.unet.dtype != self.torch_dtype: + self.unet = self.unet.to(dtype=self.torch_dtype) + + # check if get_noise prediction has guidance_embedding_scale + # if it does not, we dont pass it + signatures = inspect.signature(self.get_noise_prediction).parameters + + if 'guidance_embedding_scale' in signatures: + kwargs['guidance_embedding_scale'] = guidance_embedding_scale + if 'bypass_guidance_embedding' in signatures: + kwargs['bypass_guidance_embedding'] = bypass_guidance_embedding + if 'batch' in signatures: + kwargs['batch'] = batch + + noise_pred = self.get_noise_prediction( + latent_model_input=latent_model_input, + timestep=timestep, + text_embeddings=text_embeddings, + **kwargs + ) + + conditional_pred = noise_pred + + if do_classifier_free_guidance: + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0) + conditional_pred = noise_pred_text + if detach_unconditional: + noise_pred_uncond = noise_pred_uncond.detach() + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if rescale_cfg is not None and rescale_cfg != guidance_scale: + with torch.no_grad(): + # do cfg at the target rescale so we can match it + target_pred_mean_std = noise_pred_uncond + rescale_cfg * ( + noise_pred_text - noise_pred_uncond + ) + target_mean = target_pred_mean_std.mean( + [1, 2, 3], keepdim=True).detach() + target_std = target_pred_mean_std.std( + [1, 2, 3], keepdim=True).detach() + + pred_mean = noise_pred.mean( + [1, 2, 3], keepdim=True).detach() + pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach() + + # match the mean and std + noise_pred = (noise_pred - pred_mean) / pred_std + noise_pred 
= (noise_pred * target_std) + target_mean + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if return_conditional_pred: + return noise_pred, conditional_pred + return noise_pred + + def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None): + if noise_scheduler is None: + noise_scheduler = self.noise_scheduler + # // sometimes they are on the wrong device, no idea why + if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler): + try: + noise_scheduler.betas = noise_scheduler.betas.to( + self.device_torch) + noise_scheduler.alphas = noise_scheduler.alphas.to( + self.device_torch) + noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to( + self.device_torch) + except Exception as e: + pass + + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0) + timestep_chunks = torch.chunk( + timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + if len(timestep_chunks) == 1 and len(mi_chunks) > 1: + # expand timestep to match + timestep_chunks = timestep_chunks * len(mi_chunks) + + for idx in range(model_input.shape[0]): + # Reset it so it is unique for the + if hasattr(noise_scheduler, '_step_index'): + noise_scheduler._step_index = None + if hasattr(noise_scheduler, 'is_scale_input_called'): + noise_scheduler.is_scale_input_called = True + out_chunks.append( + noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[ + 0] + ) + return torch.cat(out_chunks, dim=0) + + # ref: 
https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 + def diffuse_some_steps( + self, + latents: torch.FloatTensor, + text_embeddings: PromptEmbeds, + total_timesteps: int = 1000, + start_timesteps=0, + guidance_scale=1, + add_time_ids=None, + bleed_ratio: float = 0.5, + bleed_latents: torch.FloatTensor = None, + is_input_scaled=False, + return_first_prediction=False, + **kwargs, + ): + timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps] + + first_prediction = None + + for timestep in tqdm(timesteps_to_run, leave=False): + timestep = timestep.unsqueeze_(0) + noise_pred, conditional_pred = self.predict_noise( + latents, + text_embeddings, + timestep, + guidance_scale=guidance_scale, + add_time_ids=add_time_ids, + is_input_scaled=is_input_scaled, + return_conditional_pred=True, + **kwargs, + ) + # some schedulers need to run separately, so do that. (euler for example) + + if return_first_prediction and first_prediction is None: + first_prediction = conditional_pred + + latents = self.step_scheduler(noise_pred, latents, timestep) + + # if not last step, and bleeding, bleed in some latents + if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]: + latents = (latents * (1 - bleed_ratio)) + \ + (bleed_latents * bleed_ratio) + + # only skip first scaling + is_input_scaled = False + + # return latents_steps + if return_first_prediction: + return latents, first_prediction + return latents + + def encode_prompt( + self, + prompt, + prompt2=None, + num_images_per_prompt=1, + force_all=False, + long_prompts=False, + max_length=None, + dropout_prob=0.0, + control_images=None, + ) -> PromptEmbeds: + # sd1.5 embeddings are (bs, 77, 768) + prompt = prompt + # if it is not a list, make it one + if not isinstance(prompt, list): + prompt = [prompt] + + if prompt2 is not None and not isinstance(prompt2, list): + prompt2 
= [prompt2] + # if control_images in the signature, pass it. This keep from breaking plugins + if self.encode_control_in_text_embeddings: + return self.get_prompt_embeds(prompt, control_images=control_images) + + return self.get_prompt_embeds(prompt) + + @torch.no_grad() + def encode_images( + self, + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + latent_list = [] + # Move to vae to device if on cpu + if self.vae.device == torch.device("cpu"): + self.vae.to(device) + self.vae.eval() + self.vae.requires_grad_(False) + # move to device and dtype + image_list = [image.to(device, dtype=dtype) for image in image_list] + + VAE_SCALE_FACTOR = 2 ** ( + len(self.vae.config['block_out_channels']) - 1) + + # resize images if not divisible by 8 + for i in range(len(image_list)): + image = image_list[i] + if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0: + image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, + image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image) + + images = torch.stack(image_list).to(device, dtype=dtype) + if isinstance(self.vae, AutoencoderTiny): + latents = self.vae.encode(images, return_dict=False)[0] + else: + latents = self.vae.encode(images).latent_dist.sample() + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + + # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303 + # z = self.scale_factor * (z - self.shift_factor) + latents = self.vae.config['scaling_factor'] * (latents - shift) + latents = latents.to(device, dtype=dtype) + + return latents + + def encode_audio(self, audio_data_list): + # audio_date_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)} + raise NotImplementedError("Audio encoding not 
implemented for this model.") + + def decode_latents( + self, + latents: torch.Tensor, + device=None, + dtype=None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + # Move to vae to device if on cpu + if self.vae.device == torch.device('cpu'): + self.vae.to(self.device) + latents = latents.to(device, dtype=dtype) + latents = ( + latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] + images = self.vae.decode(latents).sample + images = images.to(device, dtype=dtype) + + return images + + def encode_image_prompt_pairs( + self, + prompt_list: List[str], + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + # todo check image types and expand and rescale as needed + # device and dtype are for outputs + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + embedding_list = [] + latent_list = [] + # embed the prompts + for prompt in prompt_list: + embedding = self.encode_prompt(prompt).to( + self.device_torch, dtype=dtype) + embedding_list.append(embedding) + + return embedding_list, latent_list + + def get_weight_by_name(self, name): + # weights begin with te{te_num}_ for text encoder + # weights begin with unet_ for unet_ + if name.startswith('te'): + key = name[4:] + # text encoder + te_num = int(name[2]) + if isinstance(self.text_encoder, list): + return self.text_encoder[te_num].state_dict()[key] + else: + return self.text_encoder.state_dict()[key] + elif name.startswith('unet'): + key = name[5:] + # unet + return self.unet.state_dict()[key] + + raise ValueError(f"Unknown weight name: {name}") + + def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False): + return inject_trigger_into_prompt( + prompt, + trigger=trigger, + to_replace_list=to_replace_list, + add_if_not_present=add_if_not_present, + ) + + def state_dict(self, vae=True, text_encoder=True, unet=True): + state_dict = OrderedDict() 
+ if vae: + for k, v in self.vae.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + state_dict[new_key] = v + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + for k, v in encoder.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}" + state_dict[new_key] = v + else: + for k, v in self.text_encoder.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}" + state_dict[new_key] = v + if unet: + for k, v in self.unet.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + state_dict[new_key] = v + return state_dict + + def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \ + OrderedDict[ + str, Parameter]: + named_params: OrderedDict[str, Parameter] = OrderedDict() + if vae: + for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"): + named_params[name] = param + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0: + # dont add these params + continue + if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1: + # dont add these params + continue + + for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"): + named_params[name] = param + else: + for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): + named_params[name] = param + if unet: + if self.is_flux or self.is_lumina2 or self.is_transformer: + for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"): + named_params[name] = param + else: + for name, param in 
self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): + named_params[name] = param + + if self.model_config.ignore_if_contains is not None: + # remove params that contain the ignore_if_contains from named params + for key in list(named_params.keys()): + if any([s in f"transformer.{key}" for s in self.model_config.ignore_if_contains]): + del named_params[key] + if self.model_config.only_if_contains is not None: + # remove params that do not contain the only_if_contains from named params + for key in list(named_params.keys()): + if not any([s in f"transformer.{key}" for s in self.model_config.only_if_contains]): + del named_params[key] + + if refiner: + for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"): + named_params[name] = param + + # convert to state dict keys, jsut replace . with _ on keys + if state_dict_keys: + new_named_params = OrderedDict() + for k, v in named_params.items(): + # replace only the first . with an _ + new_key = k.replace('.', '_', 1) + new_named_params[new_key] = v + named_params = new_named_params + + return named_params + + def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): + self.save_model( + output_path=output_file, + meta=meta, + save_dtype=save_dtype + ) + + def prepare_optimizer_params( + self, + unet=False, + text_encoder=False, + text_encoder_lr=None, + unet_lr=None, + refiner_lr=None, + refiner=False, + default_lr=1e-6, + ): + # todo maybe only get locon ones? + # not all items are saved, to make it match, we need to match out save mappings + # and not train anything not mapped. 
Also add learning rate + version = 'sd1' + if self.is_xl: + version = 'sdxl' + if self.is_v2: + version = 'sd2' + mapping_filename = f"stable_diffusion_{version}.json" + mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename) + with open(mapping_path, 'r') as f: + mapping = json.load(f) + ldm_diffusers_keymap = mapping['ldm_diffusers_keymap'] + + trainable_parameters = [] + + # we use state dict to find params + + if unet: + named_params = self.named_parameters( + vae=False, unet=unet, text_encoder=False, state_dict_keys=True) + unet_lr = unet_lr if unet_lr is not None else default_lr + params = [] + for param in named_params.values(): + if param.requires_grad: + params.append(param) + + param_data = {"params": params, "lr": unet_lr} + trainable_parameters.append(param_data) + print_acc(f"Found {len(params)} trainable parameter in unet") + + if text_encoder: + named_params = self.named_parameters( + vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True) + text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": text_encoder_lr} + trainable_parameters.append(param_data) + + print_acc( + f"Found {len(params)} trainable parameter in text encoder") + + if refiner: + named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, + state_dict_keys=True) + refiner_lr = refiner_lr if refiner_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + diffusers_key = f"refiner_{diffusers_key}" + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + 
param_data = {"params": params, "lr": refiner_lr} + trainable_parameters.append(param_data) + + print_acc(f"Found {len(params)} trainable parameter in refiner") + + return trainable_parameters + + def save_device_state(self): + # saves the current device state for all modules + # this is useful for when we want to alter the state and restore it + unet_has_grad = self.get_model_has_grad() + + self.device_state = { + **empty_preset, + 'vae': { + 'training': self.vae.training, + 'device': self.vae.device, + }, + 'unet': { + 'training': self.unet.training, + 'device': self.unet.device, + 'requires_grad': unet_has_grad, + }, + } + if isinstance(self.text_encoder, list): + self.device_state['text_encoder']: List[dict] = [] + for encoder in self.text_encoder: + te_has_grad = self.get_te_has_grad() + self.device_state['text_encoder'].append({ + 'training': encoder.training, + 'device': encoder.device, + # todo there has to be a better way to do this + 'requires_grad': te_has_grad + }) + else: + te_has_grad = self.get_te_has_grad() + + self.device_state['text_encoder'] = { + 'training': self.text_encoder.training, + 'device': self.text_encoder.device, + 'requires_grad': te_has_grad + } + if self.adapter is not None: + if isinstance(self.adapter, IPAdapter): + requires_grad = self.adapter.image_proj_model.training + adapter_device = self.unet.device + elif isinstance(self.adapter, T2IAdapter): + requires_grad = self.adapter.adapter.conv_in.weight.requires_grad + adapter_device = self.adapter.device + elif isinstance(self.adapter, ControlNetModel): + requires_grad = self.adapter.conv_in.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ClipVisionAdapter): + requires_grad = self.adapter.embedder.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, CustomAdapter): + requires_grad = self.adapter.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ReferenceAdapter): + # todo update this!! 
+ requires_grad = True + adapter_device = self.adapter.device + else: + raise ValueError(f"Unknown adapter type: {type(self.adapter)}") + self.device_state['adapter'] = { + 'training': self.adapter.training, + 'device': adapter_device, + 'requires_grad': requires_grad, + } + + if self.refiner_unet is not None: + self.device_state['refiner_unet'] = { + 'training': self.refiner_unet.training, + 'device': self.refiner_unet.device, + 'requires_grad': self.refiner_unet.conv_in.weight.requires_grad, + } + + def restore_device_state(self): + # restores the device state for all modules + # this is useful for when we want to alter the state and restore it + if self.device_state is None: + return + self.set_device_state(self.device_state) + self.device_state = None + + def set_device_state(self, state): + if state['vae']['training']: + self.vae.train() + else: + self.vae.eval() + self.vae.to(state['vae']['device']) + if state['unet']['training']: + self.unet.train() + else: + self.unet.eval() + self.unet.to(state['unet']['device']) + if state['unet']['requires_grad']: + self.unet.requires_grad_(True) + else: + self.unet.requires_grad_(False) + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if isinstance(state['text_encoder'], list): + if state['text_encoder'][i]['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder'][i]['device']) + encoder.requires_grad_( + state['text_encoder'][i]['requires_grad']) + else: + if state['text_encoder']['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder']['device']) + encoder.requires_grad_( + state['text_encoder']['requires_grad']) + else: + if state['text_encoder']['training']: + self.text_encoder.train() + else: + self.text_encoder.eval() + self.text_encoder.to(state['text_encoder']['device']) + self.text_encoder.requires_grad_( + state['text_encoder']['requires_grad']) + + if self.adapter is not None: + 
self.adapter.to(state['adapter']['device']) + self.adapter.requires_grad_(state['adapter']['requires_grad']) + if state['adapter']['training']: + self.adapter.train() + else: + self.adapter.eval() + + if self.refiner_unet is not None: + self.refiner_unet.to(state['refiner_unet']['device']) + self.refiner_unet.requires_grad_( + state['refiner_unet']['requires_grad']) + if state['refiner_unet']['training']: + self.refiner_unet.train() + else: + self.refiner_unet.eval() + flush() + + def set_device_state_preset(self, device_state_preset: DeviceStatePreset): + # sets a preset for device state + + # save current state first + self.save_device_state() + + active_modules = [] + training_modules = [] + if device_state_preset in ['cache_latents']: + active_modules = ['vae'] + if device_state_preset in ['cache_clip']: + active_modules = ['clip'] + if device_state_preset in ['generate']: + active_modules = ['vae', 'unet', + 'text_encoder', 'adapter', 'refiner_unet'] + + state = copy.deepcopy(empty_preset) + # vae + state['vae'] = { + 'training': 'vae' in training_modules, + 'device': self.vae_device_torch if 'vae' in active_modules else 'cpu', + 'requires_grad': 'vae' in training_modules, + } + + # unet + state['unet'] = { + 'training': 'unet' in training_modules, + 'device': self.device_torch if 'unet' in active_modules else 'cpu', + 'requires_grad': 'unet' in training_modules, + } + + if self.refiner_unet is not None: + state['refiner_unet'] = { + 'training': 'refiner_unet' in training_modules, + 'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu', + 'requires_grad': 'refiner_unet' in training_modules, + } + + # text encoder + if isinstance(self.text_encoder, list): + state['text_encoder'] = [] + for i, encoder in enumerate(self.text_encoder): + state['text_encoder'].append({ + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in 
training_modules, + }) + else: + state['text_encoder'] = { + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, + } + + if self.adapter is not None: + state['adapter'] = { + 'training': 'adapter' in training_modules, + 'device': self.device_torch if 'adapter' in active_modules else 'cpu', + 'requires_grad': 'adapter' in training_modules, + } + + self.set_device_state(state) + + def text_encoder_to(self, *args, **kwargs): + if isinstance(self.text_encoder, list): + for encoder in self.text_encoder: + encoder.to(*args, **kwargs) + else: + self.text_encoder.to(*args, **kwargs) + + def convert_lora_weights_before_save(self, state_dict): + # can be overridden in child classes to convert weights before saving + return state_dict + + def convert_lora_weights_before_load(self, state_dict): + # can be overridden in child classes to convert weights before loading + return state_dict + + def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'): + # can be overridden in child classes to condition latents before noise prediction + return latents + + def get_transformer_block_names(self) -> Optional[List[str]]: + # override in child classes to get transformer block names for lora targeting + return None + + def get_quantization_exclude_modules(self) -> Optional[List[str]]: + # override in child classes to keep sensitive modules in full precision when + # quantizing. Returns fnmatch patterns matched against the transformer's module + # names (e.g. "model.x_embedder*"). + return None + + def get_base_model_version(self) -> str: + # override in child classes to get the base model version + return self.arch if self.arch is not None else 'unknown' + + def get_model_to_train(self): + # called to get model to attach LoRAs to. 
Can be overridden in child classes + return self.unet + + def scale_loss(self, loss): + # called to get the loss scaler for the model. Can be overridden in child classes + return loss diff --git a/ai-toolkit/toolkit/models/block.py b/ai-toolkit/toolkit/models/block.py new file mode 100644 index 0000000000000000000000000000000000000000..76356b5e3eb7c7d6dc4ed1629aac318c264111c5 --- /dev/null +++ b/ai-toolkit/toolkit/models/block.py @@ -0,0 +1,549 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from collections import OrderedDict + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import torch +import torch.nn as nn + + +#################### +# Basic blocks +#################### + + +def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1): + # helper selecting activation + # neg_slope: for leakyrelu and init of prelu + # n_prelu: for p_relu num_parameters + act_type = act_type.lower() + if act_type == "relu": + layer = nn.ReLU(inplace) + elif act_type == "leakyrelu": + layer = nn.LeakyReLU(neg_slope, inplace) + elif act_type == "prelu": + layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) + else: + raise NotImplementedError( + "activation layer [{:s}] is not found".format(act_type) + ) + return layer + + +def norm(norm_type: str, nc: int): + # helper selecting normalization layer + norm_type = norm_type.lower() + if norm_type == "batch": + layer = nn.BatchNorm2d(nc, affine=True) + elif norm_type == "instance": + layer = nn.InstanceNorm2d(nc, affine=False) + else: + raise NotImplementedError( + "normalization layer [{:s}] is not found".format(norm_type) + ) + return layer + + +def pad(pad_type: str, padding): + # helper selecting padding layer + # if padding is 'zero', do by conv layers + pad_type = pad_type.lower() + if padding == 0: + return None + if pad_type == "reflect": + layer = nn.ReflectionPad2d(padding) + elif pad_type == "replicate": + layer = 
nn.ReplicationPad2d(padding) + else: + raise NotImplementedError( + "padding layer [{:s}] is not implemented".format(pad_type) + ) + return layer + + +def get_valid_padding(kernel_size, dilation): + kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) + padding = (kernel_size - 1) // 2 + return padding + + +class ConcatBlock(nn.Module): + # Concat the output of a submodule to its input + def __init__(self, submodule): + super(ConcatBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = torch.cat((x, self.sub(x)), dim=1) + return output + + def __repr__(self): + tmpstr = "Identity .. \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +class ShortcutBlock(nn.Module): + # Elementwise sum the output of a submodule to its input + def __init__(self, submodule): + super(ShortcutBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = x + self.sub(x) + return output + + def __repr__(self): + tmpstr = "Identity + \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +class ShortcutBlockSPSR(nn.Module): + # Elementwise sum the output of a submodule to its input + def __init__(self, submodule): + super(ShortcutBlockSPSR, self).__init__() + self.sub = submodule + + def forward(self, x): + return x, self.sub + + def __repr__(self): + tmpstr = "Identity + \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +def sequential(*args): + # Flatten Sequential. It unwraps nn.Sequential. + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError("sequential does not support OrderedDict input.") + return args[0] # No sequential is needed. 
+ modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + +ConvMode = Literal["CNA", "NAC", "CNAC"] + + +# 2x2x2 Conv Block +def conv_block_2c2( + in_nc, + out_nc, + act_type="relu", +): + return sequential( + nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), + nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), + act(act_type) if act_type else None, + ) + + +def conv_block( + in_nc: int, + out_nc: int, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + pad_type="zero", + norm_type: str | None = None, + act_type: str | None = "relu", + mode: ConvMode = "CNA", + c2x2=False, +): + """ + Conv layer with padding, normalization, activation + mode: CNA --> Conv -> Norm -> Act + NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) + """ + + if c2x2: + return conv_block_2c2(in_nc, out_nc, act_type=act_type) + + assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) + padding = get_valid_padding(kernel_size, dilation) + p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None + padding = padding if pad_type == "zero" else 0 + + c = nn.Conv2d( + in_nc, + out_nc, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + groups=groups, + ) + a = act(act_type) if act_type else None + if mode in ("CNA", "CNAC"): + n = norm(norm_type, out_nc) if norm_type else None + return sequential(p, c, n, a) + elif mode == "NAC": + if norm_type is None and act_type is not None: + a = act(act_type, inplace=False) + # Important! 
+ # input----ReLU(inplace)----Conv--+----output + # |________________________| + # inplace ReLU will modify the input, therefore wrong output + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c) + else: + assert False, f"Invalid conv mode {mode}" + + +#################### +# Useful blocks +#################### + + +class ResNetBlock(nn.Module): + """ + ResNet Block, 3-3 style + with extra residual scaling used in EDSR + (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17) + """ + + def __init__( + self, + in_nc, + mid_nc, + out_nc, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + bias=True, + pad_type="zero", + norm_type=None, + act_type="relu", + mode: ConvMode = "CNA", + res_scale=1, + ): + super(ResNetBlock, self).__init__() + conv0 = conv_block( + in_nc, + mid_nc, + kernel_size, + stride, + dilation, + groups, + bias, + pad_type, + norm_type, + act_type, + mode, + ) + if mode == "CNA": + act_type = None + if mode == "CNAC": # Residual path: |-CNAC-| + act_type = None + norm_type = None + conv1 = conv_block( + mid_nc, + out_nc, + kernel_size, + stride, + dilation, + groups, + bias, + pad_type, + norm_type, + act_type, + mode, + ) + # if in_nc != out_nc: + # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \ + # None, None) + # print('Need a projecter in ResNetBlock.') + # else: + # self.project = lambda x:x + self.res = sequential(conv0, conv1) + self.res_scale = res_scale + + def forward(self, x): + res = self.res(x).mul(self.res_scale) + return x + res + + +class RRDB(nn.Module): + """ + Residual in Residual Dense Block + (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) + """ + + def __init__( + self, + nf, + kernel_size=3, + gc=32, + stride=1, + bias: bool = True, + pad_type="zero", + norm_type=None, + act_type="leakyrelu", + mode: ConvMode = "CNA", + _convtype="Conv2D", + _spectral_norm=False, + plus=False, + c2x2=False, + ): + super(RRDB, 
self).__init__() + self.RDB1 = ResidualDenseBlock_5C( + nf, + kernel_size, + gc, + stride, + bias, + pad_type, + norm_type, + act_type, + mode, + plus=plus, + c2x2=c2x2, + ) + self.RDB2 = ResidualDenseBlock_5C( + nf, + kernel_size, + gc, + stride, + bias, + pad_type, + norm_type, + act_type, + mode, + plus=plus, + c2x2=c2x2, + ) + self.RDB3 = ResidualDenseBlock_5C( + nf, + kernel_size, + gc, + stride, + bias, + pad_type, + norm_type, + act_type, + mode, + plus=plus, + c2x2=c2x2, + ) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class ResidualDenseBlock_5C(nn.Module): + """ + Residual Dense Block + style: 5 convs + The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) + Modified options that can be used: + - "Partial Convolution based Padding" arXiv:1811.11718 + - "Spectral normalization" arXiv:1802.05957 + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. + {Rakotonirina} and A. {Rasoanaivo} + + Args: + nf (int): Channel number of intermediate features (num_feat). + gc (int): Channels for each growth (num_grow_ch: growth channel, + i.e. intermediate channels). + convtype (str): the type of convolution to use. 
Default: 'Conv2D' + gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new + trainable parameters) + plus (bool): enable the additional residual paths from ESRGAN+ + (adds trainable parameters) + """ + + def __init__( + self, + nf=64, + kernel_size=3, + gc=32, + stride=1, + bias: bool = True, + pad_type="zero", + norm_type=None, + act_type="leakyrelu", + mode: ConvMode = "CNA", + plus=False, + c2x2=False, + ): + super(ResidualDenseBlock_5C, self).__init__() + + ## + + self.conv1x1 = conv1x1(nf, gc) if plus else None + ## + + + self.conv1 = conv_block( + nf, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + self.conv2 = conv_block( + nf + gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + self.conv3 = conv_block( + nf + 2 * gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + self.conv4 = conv_block( + nf + 3 * gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + if mode == "CNA": + last_act = None + else: + last_act = act_type + self.conv5 = conv_block( + nf + 4 * gc, + nf, + 3, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=last_act, + mode=mode, + c2x2=c2x2, + ) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(torch.cat((x, x1), 1)) + if self.conv1x1: + # pylint: disable=not-callable + x2 = x2 + self.conv1x1(x) # + + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) + if self.conv1x1: + x4 = x4 + x2 # + + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 
bias=False) + + +#################### +# Upsampler +#################### + + +def pixelshuffle_block( + in_nc: int, + out_nc: int, + upscale_factor=2, + kernel_size=3, + stride=1, + bias=True, + pad_type="zero", + norm_type: str | None = None, + act_type="relu", +): + """ + Pixel shuffle layer + (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional + Neural Network, CVPR17) + """ + conv = conv_block( + in_nc, + out_nc * (upscale_factor ** 2), + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=None, + act_type=None, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + + n = norm(norm_type, out_nc) if norm_type else None + a = act(act_type) if act_type else None + return sequential(conv, pixel_shuffle, n, a) + + +def upconv_block( + in_nc: int, + out_nc: int, + upscale_factor=2, + kernel_size=3, + stride=1, + bias=True, + pad_type="zero", + norm_type: str | None = None, + act_type="relu", + mode="nearest", + c2x2=False, +): + # Up conv + # described in https://distill.pub/2016/deconv-checkerboard/ + # convert to float 16 if is bfloat16 + upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) + conv = conv_block( + in_nc, + out_nc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + c2x2=c2x2, + ) + return sequential(upsample, conv) diff --git a/ai-toolkit/toolkit/models/clip_fusion.py b/ai-toolkit/toolkit/models/clip_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f4346fd5ac3eae4c8d91e50df586acc8d4cd2fbe --- /dev/null +++ b/ai-toolkit/toolkit/models/clip_fusion.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn + +from toolkit.models.zipper_resampler import ContextualAlphaMask + + +# Conv1d MLP +# MLP that can alternately be used as a conv1d on dim 1 +class MLPC(nn.Module): + def __init__( + self, + in_dim, + out_dim, + hidden_dim, + do_conv=False, + use_residual=True + ): + super().__init__() + self.do_conv 
= do_conv + if use_residual: + assert in_dim == out_dim + # dont normalize if using conv + if not do_conv: + self.layernorm = nn.LayerNorm(in_dim) + + if do_conv: + self.fc1 = nn.Conv1d(in_dim, hidden_dim, 1) + self.fc2 = nn.Conv1d(hidden_dim, out_dim, 1) + else: + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + if not self.do_conv: + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class ZipperBlock(nn.Module): + def __init__( + self, + in_size, + in_tokens, + out_size, + out_tokens, + hidden_size, + hidden_tokens, + ): + super().__init__() + self.in_size = in_size + self.in_tokens = in_tokens + self.out_size = out_size + self.out_tokens = out_tokens + self.hidden_size = hidden_size + self.hidden_tokens = hidden_tokens + # permute to (batch_size, out_size, in_tokens) + + self.zip_token = MLPC( + in_dim=self.in_tokens, + out_dim=self.out_tokens, + hidden_dim=self.hidden_tokens, + do_conv=True, # no need to permute + use_residual=False + ) + + # permute to (batch_size, out_tokens, out_size) + + # in shpae: (batch_size, in_tokens, in_size) + self.zip_size = MLPC( + in_dim=self.in_size, + out_dim=self.out_size, + hidden_dim=self.hidden_size, + use_residual=False + ) + + def forward(self, x): + x = self.zip_token(x) + x = self.zip_size(x) + return x + + + + + + +# CLIPFusionModule +# Fuses any size of vision and text embeddings into a single embedding. +# remaps tokens and vectors. 
+class CLIPFusionModule(nn.Module): + def __init__( + self, + text_hidden_size: int = 768, + text_tokens: int = 77, + vision_hidden_size: int = 1024, + vision_tokens: int = 257, + num_blocks: int = 1, + ): + super(CLIPFusionModule, self).__init__() + + self.text_hidden_size = text_hidden_size + self.text_tokens = text_tokens + self.vision_hidden_size = vision_hidden_size + self.vision_tokens = vision_tokens + + self.resampler = ZipperBlock( + in_size=self.vision_hidden_size, + in_tokens=self.vision_tokens, + out_size=self.text_hidden_size, + out_tokens=self.text_tokens, + hidden_size=self.vision_hidden_size * 2, + hidden_tokens=self.vision_tokens * 2 + ) + + self.zipper_blocks = torch.nn.ModuleList([ + ZipperBlock( + in_size=self.text_hidden_size * 2, + in_tokens=self.text_tokens, + out_size=self.text_hidden_size, + out_tokens=self.text_tokens, + hidden_size=self.text_hidden_size * 2, + hidden_tokens=self.text_tokens * 2 + ) for i in range(num_blocks) + ]) + + self.ctx_alpha = ContextualAlphaMask( + dim=self.text_hidden_size, + ) + + self.alpha = nn.Parameter(torch.zeros([text_tokens]) + 0.01) + + def forward(self, text_embeds, vision_embeds): + # text_embeds = (batch_size, 77, 768) + # vision_embeds = (batch_size, 257, 1024) + # output = (batch_size, 77, 768) + + vision_embeds = self.resampler(vision_embeds) + x = vision_embeds + for i, block in enumerate(self.zipper_blocks): + res = x + x = torch.cat([text_embeds, x], dim=-1) + x = block(x) + x = x + res + + # alpha mask + ctx_alpha = self.ctx_alpha(text_embeds) + # reshape alpha to (1, 77, 1) + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + + x = ctx_alpha * x * alpha + + x = x + text_embeds + + return x diff --git a/ai-toolkit/toolkit/models/clip_pre_processor.py b/ai-toolkit/toolkit/models/clip_pre_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..7956da0b4d0a5d6b882d4d19d9458bf409cc9b39 --- /dev/null +++ b/ai-toolkit/toolkit/models/clip_pre_processor.py @@ -0,0 +1,123 @@ 
import torch
import torch.nn as nn


class UpsampleBlock(nn.Module):
    """Refine, double the spatial resolution, then refine again.

    Pipeline: 3x3 conv + GELU -> 2x2 stride-2 transposed conv + GELU -> 3x3 conv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the upsampled feature map.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_in = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.GELU()
        )
        self.conv_up = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.GELU()
        )

        self.conv_out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        """Map (bs, in_channels, H, W) to (bs, out_channels, 2H, 2W)."""
        x = self.conv_in(x)
        x = self.conv_up(x)
        x = self.conv_out(x)
        return x


class CLIPImagePreProcessor(nn.Module):
    """Learned downscaler from a large RGB image to the CLIP input resolution.

    The image is resized to ``input_size``, pixel-unshuffled by
    ``downscale_factor`` into a low-resolution / high-channel tensor, then
    upsampled by a stack of ``UpsampleBlock`` layers until it reaches
    ``clip_input_size``. After every block a pixel-unshuffled copy of the
    resized image (same shape as the block output) is added as a residual.
    The final 3-channel prediction is scaled by the learned ``res_blend`` and
    added to an average-pooled copy of the resized image.

    Args:
        input_size: side length the input is resized to before processing.
        clip_input_size: side length of the produced image. Must divide
            ``input_size``.
        downscale_factor: pixel-unshuffle factor applied to the resized input.
            ``downscale_factor / (input_size // clip_input_size)`` must be a
            power of two (1, 2, 4, 8, ...) and ``downscale_factor`` must divide
            ``input_size``.

    Raises:
        AssertionError: if ``input_size`` is not a multiple of ``clip_input_size``.
        ValueError: if the sizes cannot be connected by 2x upsample blocks.
    """

    def __init__(
            self,
            input_size=896,
            clip_input_size=224,
            downscale_factor: int = 16,
    ):
        super().__init__()
        # make sure they are evenly divisible
        assert input_size % clip_input_size == 0
        in_channels = 3

        self.input_size = input_size
        self.clip_input_size = clip_input_size
        self.downscale_factor = downscale_factor

        size_ratio = input_size // clip_input_size  # 896 // 224 = 4
        if input_size % downscale_factor != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by "
                f"downscale_factor ({downscale_factor})"
            )
        if downscale_factor % size_ratio != 0:
            raise ValueError(
                f"downscale_factor ({downscale_factor}) must be a multiple of "
                f"input_size // clip_input_size ({size_ratio})"
            )

        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 16 ** 2 = 768
        channels = subpixel_channels

        upscale_factor = int(downscale_factor // size_ratio)  # 16 // 4 = 4
        if upscale_factor < 1 or (upscale_factor & (upscale_factor - 1)) != 0:
            raise ValueError(
                f"downscale_factor / (input_size // clip_input_size) must be a "
                f"power of two, got {upscale_factor}"
            )

        # Each block doubles the resolution, so log2(upscale_factor) blocks are
        # needed. The previous `upscale_factor // 2` only matched this for
        # factors 1, 2 and 4 and built too many blocks for 8 and above.
        num_upsample_blocks = upscale_factor.bit_length() - 1  # log2(4) = 2

        # make the residual down up blocks
        self.upsample_blocks = nn.ModuleList()
        self.subpixel_blocks = nn.ModuleList()
        current_channels = channels
        current_downscale = downscale_factor
        for _ in range(num_upsample_blocks):
            # determine the reshuffled channel count for this dimension
            output_downscale = current_downscale // 2
            out_channels = in_channels * output_downscale ** 2
            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
            current_channels = out_channels
            current_downscale = output_downscale
            # residual branch: unshuffle the image to this block's output shape
            self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale))

        # with the defaults:
        # (bs, 768, 56, 56) -> (bs, 192, 112, 112)
        # (bs, 192, 112, 112) -> (bs, 48, 224, 224)

        self.conv_out = nn.Conv2d(
            current_channels,
            out_channels=3,
            kernel_size=3,
            padding=1
        )  # (bs, 48, 224, 224) -> (bs, 3, 224, 224)

        # average-pool the resized input down by input_size // clip_input_size
        # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
        kernel_size = size_ratio
        self.res_down = nn.AvgPool2d(
            kernel_size=kernel_size,
            stride=kernel_size
        )

        # blend weight for the learned branch, initialised near 0 so the module
        # starts out as (almost) plain average pooling
        self.res_blend = nn.Parameter(torch.tensor(0.001))

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 896, 896) -> (bs, 768, 56, 56)

        self.conv_in = nn.Sequential(
            nn.Conv2d(
                subpixel_channels,
                channels,
                kernel_size=3,
                padding=1
            ),
            nn.GELU()
        )  # (bs, 768, 56, 56) -> (bs, 768, 56, 56)

    def forward(self, x):
        """Return a (bs, 3, clip_input_size, clip_input_size) image for ``x``.

        ``x`` is a (bs, 3, H, W) float tensor; any H, W is accepted because the
        input is bicubically resized to ``input_size`` first.
        """
        inputs = x
        target_hw = (self.input_size, self.input_size)
        # resize to input_size x input_size
        x = nn.functional.interpolate(x, size=target_hw, mode='bicubic')

        # The residual branches must see an image of exactly input_size,
        # otherwise their shapes do not line up with the learned branch.
        # Reuse the untouched input when it already has the right size
        # (identical to the previous behaviour), else use the resized copy.
        if tuple(inputs.shape[-2:]) == target_hw:
            residual_source = inputs
        else:
            residual_source = x

        res = self.res_down(residual_source)

        x = self.unshuffle(x)
        x = self.conv_in(x)
        for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks):
            x = up(x)
            x = x + subpixel(residual_source)
        x = self.conv_out(x)
        # blend residual
        x = x * self.res_blend + res
        return x


# ---- patch boundary: unified-diff header preserved from the original line ----
# diff --git a/ai-toolkit/toolkit/models/cogview4.py b/ai-toolkit/toolkit/models/cogview4.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..13fa42676d8a1a101a3aa029074de7e13425aa73
# --- /dev/null
# +++ b/ai-toolkit/toolkit/models/cogview4.py
# @@ -0,0 +1,467 @@
# DONT USE THIS!. IT DOES NOT WORK YET!
# Will revisit this when they release more info on how it was trained.
+ +import weakref +from diffusers import CogView4Pipeline +import torch +import yaml + +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.models.base_model import BaseModel +from toolkit.prompt_utils import PromptEmbeds + +import os +import copy +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch +import torch +import diffusers +from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline +from optimum.quanto import freeze, qfloat8, QTensor, qint4 +from toolkit.util.quantize import quantize, get_qtype +from transformers import GlmModel, AutoTokenizer +from diffusers import FlowMatchEulerDiscreteScheduler +from typing import TYPE_CHECKING +from toolkit.accelerator import unwrap_model +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# remove this after a bug is fixed in diffusers code. This is a workaround. 
+ + +class FakeModel: + def __init__(self, model): + self.model_ref = weakref.ref(model) + pass + + @property + def device(self): + return self.model_ref().device + + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.25, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 0.75, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": None, + "time_shift_type": "linear", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False +} + + +class CogView4(BaseModel): + arch = 'cogview4' + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__(device, model_config, dtype, + custom_pipeline, noise_scheduler, **kwargs) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['CogView4Transformer2DModel'] + + # cache for holding noise + self.effective_noise = None + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def load_model(self): + dtype = self.torch_dtype + base_model_path = "THUDM/CogView4-6B" + model_path = self.model_config.name_or_path + + self.print_and_status_update("Loading CogView4 model") + # base_model_path = "black-forest-labs/FLUX.1-schnell" + base_model_path = self.model_config.name_or_path_original + subfolder = 'transformer' + transformer_path = model_path + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. 
+ te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + self.print_and_status_update("Loading GlmModel") + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder = GlmModel.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing GlmModel") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype)) + freeze(text_encoder) + flush() + + # hack to fix diffusers bug workaround + text_encoder.model = FakeModel(text_encoder) + + self.print_and_status_update("Loading transformer") + transformer = CogView4Transformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + + if self.model_config.split_model_over_gpus: + raise ValueError( + "Splitting model over gpus is not supported for CogViewModels models") + + transformer.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + raise ValueError( + "Assistant LoRA is not supported for CogViewModels models currently") + + if self.model_config.lora_path is not None: + raise ValueError( + "Loading LoRA is not supported for CogViewModels models currently") + + flush() + + if self.model_config.quantize: + quantization_args = self.model_config.quantize_kwargs + if 'exclude' not in quantization_args: + quantization_args['exclude'] = [] + if 'include' not in quantization_args: + quantization_args['include'] = [] + + # Be more specific with the include pattern to exactly match transformer blocks + quantization_args['include'] += ["transformer_blocks.*"] + + # Exclude all LayerNorm layers within transformer 
blocks + quantization_args['exclude'] += [ + "transformer_blocks.*.norm1", + "transformer_blocks.*.norm2", + "transformer_blocks.*.norm2_context", + "transformer_blocks.*.attn1.norm_q", + "transformer_blocks.*.attn1.norm_k" + ] + + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, **quantization_args) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = CogView4.get_train_scheduler() + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + self.print_and_status_update("Making pipe") + pipe: CogView4Pipeline = CogView4Pipeline( + scheduler=scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + ) + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = pipe.text_encoder + tokenizer = pipe.tokenizer + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder.to(self.device_torch) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + self.pipeline = pipe + self.model = transformer + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + + def get_generation_pipeline(self): + scheduler = CogView4.get_train_scheduler() + pipeline = CogView4Pipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + ) + return pipeline + + def generate_single_image( + self, + pipeline: CogView4Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: 
PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + # target_size = (height, width) + target_size = latent_model_input.shape[-2:] + # multiply by 8 + target_size = (target_size[0] * 8, target_size[1] * 8) + crops_coords_top_left = torch.tensor( + [(0, 0)], dtype=self.torch_dtype, device=self.device_torch) + + original_size = torch.tensor( + [target_size], dtype=self.torch_dtype, device=self.device_torch) + target_size = original_size.clone() + noise_pred_cond = self.model( + hidden_states=latent_model_input, + encoder_hidden_states=text_embeddings.text_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + return noise_pred_cond + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + prompt_embeds, _ = self.pipeline.encode_prompt( + prompt, + do_classifier_free_guidance=False, + device=self.device_torch, + dtype=self.torch_dtype, + ) + return PromptEmbeds(prompt_embeds) + + def get_model_has_grad(self): + return self.model.proj_out.weight.requires_grad + + def get_te_has_grad(self): + return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: CogView4Transformer2DModel = unwrap_model(self.model) + 
transformer.save_pretrained( + save_directory=os.path.join(output_path, 'transformer'), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + effective_noise = self.effective_noise + batch = kwargs.get('batch') + if batch is None: + raise ValueError("Batch is not provided") + if noise is None: + raise ValueError("Noise is not provided") + # return batch.latents + # return (batch.latents - noise).detach() + return (noise - batch.latents).detach() + # return (batch.latents).detach() + # return (effective_noise - batch.latents).detach() + + def _get_low_res_latents(self, latents): + # todo prevent needing to do this and grab the tensor another way. + with torch.no_grad(): + # Decode latents to image space + images = self.decode_latents( + latents, device=latents.device, dtype=latents.dtype) + + # Downsample by a factor of 2 using bilinear interpolation + B, C, H, W = images.shape + low_res_images = torch.nn.functional.interpolate( + images, + size=(H // 2, W // 2), + mode="bilinear", + align_corners=False + ) + + # Upsample back to original resolution to match expected VAE input dimensions + upsampled_low_res_images = torch.nn.functional.interpolate( + low_res_images, + size=(H, W), + mode="bilinear", + align_corners=False + ) + + # Encode the low-resolution images back to latent space + low_res_latents = self.encode_images( + upsampled_low_res_images, device=latents.device, dtype=latents.dtype) + return low_res_latents + + # def add_noise( + # self, + # original_samples: torch.FloatTensor, + # noise: torch.FloatTensor, + # timesteps: torch.IntTensor, + # **kwargs, + # ) -> torch.FloatTensor: + # relay_start_point = 500 + + # # Store original samples for loss calculation + # self.original_samples = original_samples + + # # Prepare chunks for batch processing + # original_samples_chunks = 
torch.chunk( + # original_samples, original_samples.shape[0], dim=0) + # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + # # Get the low res latents only if needed + # low_res_latents_chunks = None + + # # Handle case where timesteps is a single value for all samples + # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + # noisy_latents_chunks = [] + # effective_noise_chunks = [] # Store the effective noise for each sample + + # for idx in range(original_samples.shape[0]): + # t = timesteps_chunks[idx] + # t_01 = (t / 1000).to(original_samples_chunks[idx].device) + + # # Flowmatching interpolation between original and noise + # if t > relay_start_point: + # # Standard flowmatching - direct linear interpolation + # noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx] + # effective_noise_chunks.append(noise_chunks[idx]) # Effective noise is just the noise + # else: + # # Relay flowmatching case - only compute low_res_latents if needed + # if low_res_latents_chunks is None: + # low_res_latents = self._get_low_res_latents(original_samples) + # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) + + # # Calculate the relay ratio (0 to 1) + # t_ratio = t.float() / relay_start_point + # t_ratio = torch.clamp(t_ratio, 0.0, 1.0) + + # # First blend between original and low-res based on t_ratio + # z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx] + + # added_lor_res_noise = z0_t - original_samples_chunks[idx] + + # # Then apply flowmatching interpolation between this blended state and noise + # noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx] + + # # For prediction target, we need to store the effective "source" + # 
effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise) + + # noisy_latents_chunks.append(noisy_latents) + + # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation + + # return noisy_latents + + # def add_noise( + # self, + # original_samples: torch.FloatTensor, + # noise: torch.FloatTensor, + # timesteps: torch.IntTensor, + # **kwargs, + # ) -> torch.FloatTensor: + # relay_start_point = 500 + + # # Store original samples for loss calculation + # self.original_samples = original_samples + + # # Prepare chunks for batch processing + # original_samples_chunks = torch.chunk( + # original_samples, original_samples.shape[0], dim=0) + # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + # # Get the low res latents only if needed + # low_res_latents = self._get_low_res_latents(original_samples) + # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) + + # # Handle case where timesteps is a single value for all samples + # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + # noisy_latents_chunks = [] + # effective_noise_chunks = [] # Store the effective noise for each sample + + # for idx in range(original_samples.shape[0]): + # t = timesteps_chunks[idx] + # t_01 = (t / 1000).to(original_samples_chunks[idx].device) + + # lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx] + # # lrln = lrln * (1 - t_01) + + # # make the noise an interpolation between noise and low_res_latents with + # # being noise at t_01=1 and low_res_latents at t_01=0 + # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln + # # new_noise = noise_chunks[idx] + lrln + # # new_noise = noise_chunks[idx] + lrln + + # # Then apply flowmatching 
interpolation between this blended state and noise + # noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise + + # # For prediction target, we need to store the effective "source" + # effective_noise_chunks.append(new_noise) + + # noisy_latents_chunks.append(noisy_latents) + + # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation + + # return noisy_latents diff --git a/ai-toolkit/toolkit/models/control_lora_adapter.py b/ai-toolkit/toolkit/models/control_lora_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..38147ea9a2996bad38a0221281f4788d4f71da7a --- /dev/null +++ b/ai-toolkit/toolkit/models/control_lora_adapter.py @@ -0,0 +1,272 @@ +import inspect +import weakref +import torch +from typing import TYPE_CHECKING +from toolkit.lora_special import LoRASpecialNetwork +from diffusers import FluxTransformer2DModel +# weakref + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig + from toolkit.custom_adapter import CustomAdapter + + +# after each step we concat the control image with the latents +# latent_model_input = torch.cat([latents, control_image], dim=2) +# the x_embedder has a full rank lora to handle the additional channels +# this replaces the x_embedder with a full rank lora. on flux this is +# x_embedder(diffusers) or img_in(bfl) + +# Flux +# img_in.lora_A.weight [128, 128] +# img_in.lora_B.bias [3 072] +# img_in.lora_B.weight [3 072, 128] + + +class ImgEmbedder(torch.nn.Module): + def __init__( + self, + adapter: 'ControlLoraAdapter', + orig_layer: torch.nn.Linear, + in_channels=64, + out_channels=3072 + ): + super().__init__() + # only do the weight for the new input. 
We combine with the original linear layer + init = torch.randn(out_channels, in_channels, device=orig_layer.weight.device, dtype=orig_layer.weight.dtype) * 0.01 + self.weight = torch.nn.Parameter(init) + + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer) + + @classmethod + def from_model( + cls, + model: FluxTransformer2DModel, + adapter: 'ControlLoraAdapter', + num_control_images=1, + has_inpainting_input=False + ): + if model.__class__.__name__ == 'FluxTransformer2DModel': + num_adapter_in_channels = model.x_embedder.in_features * num_control_images + + if has_inpainting_input: + # inpainting has the mask before packing latents. it is normally 16 ch + 1ch mask + # packed it is 64ch + 4ch mask + # so we need to add 4 to the input channels + num_adapter_in_channels += 4 + + x_embedder: torch.nn.Linear = model.x_embedder + img_embedder = cls( + adapter, + orig_layer=x_embedder, + in_channels=num_adapter_in_channels, + out_channels=x_embedder.out_features, + ) + + # hijack the forward method + x_embedder._orig_ctrl_lora_forward = x_embedder.forward + x_embedder.forward = img_embedder.forward + + # update the config of the transformer + model.config.in_channels = model.config.in_channels * (num_control_images + 1) + model.config["in_channels"] = model.config.in_channels + + return img_embedder + else: + raise ValueError("Model not supported") + + @property + def is_active(self): + return self.adapter_ref().is_active + + + def forward(self, x): + if not self.is_active: + # make sure lora is not active + if self.adapter_ref().control_lora is not None: + self.adapter_ref().control_lora.is_active = False + return self.orig_layer_ref()._orig_ctrl_lora_forward(x) + + # make sure lora is active + if self.adapter_ref().control_lora is not None: + self.adapter_ref().control_lora.is_active = True + + orig_device = x.device + orig_dtype = x.dtype + + x = x.to(self.weight.device, dtype=self.weight.dtype) + + 
orig_weight = self.orig_layer_ref().weight.data.detach() + orig_weight = orig_weight.to(self.weight.device, dtype=self.weight.dtype) + linear_weight = torch.cat([orig_weight, self.weight], dim=1) + + bias = None + if self.orig_layer_ref().bias is not None: + bias = self.orig_layer_ref().bias.data.detach().to(self.weight.device, dtype=self.weight.dtype) + + x = torch.nn.functional.linear(x, linear_weight, bias) + + x = x.to(orig_device, dtype=orig_dtype) + return x + + + +class ControlLoraAdapter(torch.nn.Module): + def __init__( + self, + adapter: 'CustomAdapter', + sd: 'StableDiffusion', + config: 'AdapterConfig', + train_config: 'TrainConfig' + ): + super().__init__() + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref = weakref.ref(sd) + self.model_config: ModelConfig = sd.model_config + self.network_config = config.lora_config + self.train_config = train_config + self.device_torch = sd.device_torch + self.control_lora = None + + if self.network_config is not None: + + network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs + if hasattr(sd, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = sd.target_lora_modules + + if 'ignore_if_contains' not in network_kwargs: + network_kwargs['ignore_if_contains'] = [] + + # always ignore x_embedder + network_kwargs['ignore_if_contains'].append('x_embedder') + + self.control_lora = LoRASpecialNetwork( + text_encoder=sd.text_encoder, + unet=sd.unet, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + 
is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=False, + is_lorm=False, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=sd.is_transformer, + base_model=sd, + **network_kwargs + ) + self.control_lora.force_to(self.device_torch, dtype=torch.float32) + self.control_lora._update_torch_multiplier() + self.control_lora.apply_to( + sd.text_encoder, + sd.unet, + self.train_config.train_text_encoder, + self.train_config.train_unet + ) + self.control_lora.can_merge_in = False + self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet) + if self.train_config.gradient_checkpointing: + self.control_lora.enable_gradient_checkpointing() + + self.x_embedder = ImgEmbedder.from_model( + sd.unet, + self, + num_control_images=config.num_control_images, + has_inpainting_input=config.has_inpainting_input + ) + self.x_embedder.to(self.device_torch) + + def get_params(self): + if self.control_lora is not None: + config = { + 'text_encoder_lr': self.train_config.lr, + 'unet_lr': self.train_config.lr, + } + sig = inspect.signature(self.control_lora.prepare_optimizer_params) + if 'default_lr' in sig.parameters: + config['default_lr'] = self.train_config.lr + if 'learning_rate' in sig.parameters: + config['learning_rate'] = self.train_config.lr + params_net = self.control_lora.prepare_optimizer_params( + **config + ) + + # we want only tensors here + params = [] + for p in params_net: + if isinstance(p, dict): + params += p["params"] + elif isinstance(p, torch.Tensor): + params.append(p) + elif isinstance(p, list): + params += p + else: + params = [] + + # make sure the embedder is float32 + 
self.x_embedder.to(torch.float32) + + params += list(self.x_embedder.parameters()) + + # we need to be able to yield from the list like yield from params + + return params + + def load_weights(self, state_dict, strict=True): + lora_sd = {} + img_embedder_sd = {} + for key, value in state_dict.items(): + if "x_embedder" in key: + new_key = key.replace("transformer.x_embedder.", "") + img_embedder_sd[new_key] = value + else: + lora_sd[key] = value + + # todo process state dict before loading + if self.control_lora is not None: + self.control_lora.load_weights(lora_sd) + # automatically upgrade the x imbedder if more dims are added + if self.x_embedder.weight.shape[1] > img_embedder_sd['weight'].shape[1]: + print("Upgrading x_embedder from {} to {}".format( + img_embedder_sd['weight'].shape[1], + self.x_embedder.weight.shape[1] + )) + while img_embedder_sd['weight'].shape[1] < self.x_embedder.weight.shape[1]: + img_embedder_sd['weight'] = torch.cat([img_embedder_sd['weight'] ] * 2, dim=1) + if img_embedder_sd['weight'].shape[1] > self.x_embedder.weight.shape[1]: + img_embedder_sd['weight'] = img_embedder_sd['weight'][:, :self.x_embedder.weight.shape[1]] + self.x_embedder.load_state_dict(img_embedder_sd, strict=False) + + def get_state_dict(self): + if self.control_lora is not None: + lora_sd = self.control_lora.get_state_dict(dtype=torch.float32) + else: + lora_sd = {} + # todo make sure we match loras elseware. 
+ img_embedder_sd = self.x_embedder.state_dict() + for key, value in img_embedder_sd.items(): + lora_sd[f"transformer.x_embedder.{key}"] = value + return lora_sd + + @property + def is_active(self): + return self.adapter_ref().is_active diff --git a/ai-toolkit/toolkit/models/decorator.py b/ai-toolkit/toolkit/models/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..63f45aa9f944370727eed1a362c9bb04ad99fa9b --- /dev/null +++ b/ai-toolkit/toolkit/models/decorator.py @@ -0,0 +1,33 @@ +import torch + + +class Decorator(torch.nn.Module): + def __init__( + self, + num_tokens: int = 4, + token_size: int = 4096, + ) -> None: + super().__init__() + + self.weight: torch.nn.Parameter = torch.nn.Parameter( + torch.randn(num_tokens, token_size) + ) + # ensure it is float32 + self.weight.data = self.weight.data.float() + + def forward(self, text_embeds: torch.Tensor, is_unconditional=False) -> torch.Tensor: + # make sure the param is float32 + if self.weight.dtype != text_embeds.dtype: + self.weight.data = self.weight.data.float() + # expand batch to match text_embeds + batch_size = text_embeds.shape[0] + decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1) + if is_unconditional: + # zero pad the decorator embeds + decorator_embeds = torch.zeros_like(decorator_embeds) + + if decorator_embeds.dtype != text_embeds.dtype: + decorator_embeds = decorator_embeds.to(text_embeds.dtype) + text_embeds = torch.cat((text_embeds, decorator_embeds), dim=-2) + + return text_embeds diff --git a/ai-toolkit/toolkit/models/diffusion_feature_extraction.py b/ai-toolkit/toolkit/models/diffusion_feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..28631bf0630568026f69d3051728d5e5b2708283 --- /dev/null +++ b/ai-toolkit/toolkit/models/diffusion_feature_extraction.py @@ -0,0 +1,1392 @@ +import math +import torch +import os +from torch import nn +from safetensors.torch import load_file +import torch.nn.functional as F 
+import torch.utils.checkpoint as ckpt +from diffusers import AutoencoderTiny +from transformers import AutoImageProcessor, AutoModel, SiglipImageProcessor, SiglipVisionModel +import lpips +import weakref +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.models.base_model import BaseModel +from toolkit.models.sapiens2 import Sapiens2 +import huggingface_hub + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1) + self.norm1 = nn.GroupNorm(8, out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1) + self.norm2 = nn.GroupNorm(8, out_channels) + self.skip = nn.Conv2d(in_channels, out_channels, + 1) if in_channels != out_channels else nn.Identity() + + def forward(self, x): + identity = self.skip(x) + x = self.conv1(x) + x = self.norm1(x) + x = F.silu(x) + x = self.conv2(x) + x = self.norm2(x) + x = F.silu(x + identity) + return x + + +class DiffusionFeatureExtractor2(nn.Module): + def __init__(self, in_channels=32): + super().__init__() + self.version = 2 + + # Path 1: Upsample to 512x512 (1, 64, 512, 512) + self.up_path = nn.ModuleList([ + nn.Conv2d(in_channels, 64, 3, padding=1), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ResBlock(64, 64), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ResBlock(64, 64), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ResBlock(64, 64), + nn.Conv2d(64, 64, 3, padding=1), + ]) + + # Path 2: Upsample to 256x256 (1, 128, 256, 256) + self.path2 = nn.ModuleList([ + nn.Conv2d(in_channels, 128, 3, padding=1), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ResBlock(128, 128), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ResBlock(128, 128), + nn.Conv2d(128, 128, 
3, padding=1), + ]) + + # Path 3: Upsample to 128x128 (1, 256, 128, 128) + self.path3 = nn.ModuleList([ + nn.Conv2d(in_channels, 256, 3, padding=1), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ResBlock(256, 256), + nn.Conv2d(256, 256, 3, padding=1) + ]) + + # Path 4: Original size (1, 512, 64, 64) + self.path4 = nn.ModuleList([ + nn.Conv2d(in_channels, 512, 3, padding=1), + ResBlock(512, 512), + ResBlock(512, 512), + nn.Conv2d(512, 512, 3, padding=1) + ]) + + # Path 5: Downsample to 32x32 (1, 512, 32, 32) + self.path5 = nn.ModuleList([ + nn.Conv2d(in_channels, 512, 3, padding=1), + ResBlock(512, 512), + nn.AvgPool2d(2), + ResBlock(512, 512), + nn.Conv2d(512, 512, 3, padding=1) + ]) + + def forward(self, x): + outputs = [] + + # Path 1: 512x512 + x1 = x + for layer in self.up_path: + x1 = layer(x1) + outputs.append(x1) # [1, 64, 512, 512] + + # Path 2: 256x256 + x2 = x + for layer in self.path2: + x2 = layer(x2) + outputs.append(x2) # [1, 128, 256, 256] + + # Path 3: 128x128 + x3 = x + for layer in self.path3: + x3 = layer(x3) + outputs.append(x3) # [1, 256, 128, 128] + + # Path 4: 64x64 + x4 = x + for layer in self.path4: + x4 = layer(x4) + outputs.append(x4) # [1, 512, 64, 64] + + # Path 5: 32x32 + x5 = x + for layer in self.path5: + x5 = layer(x5) + outputs.append(x5) # [1, 512, 32, 32] + + return outputs + + +class DFEBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv1 = nn.Conv2d(channels, channels, 3, padding=1) + self.conv2 = nn.Conv2d(channels, channels, 3, padding=1) + self.act = nn.GELU() + self.proj = nn.Conv2d(channels, channels, 1) + + def forward(self, x): + x_in = x + x = self.conv1(x) + x = self.conv2(x) + x = self.act(x) + x = self.proj(x) + x = x + x_in + return x + + +class DiffusionFeatureExtractor(nn.Module): + def __init__(self, in_channels=16, out_channels=512): + super().__init__() + self.version = 1 + num_blocks = 6 + self.conv_in = nn.Conv2d(in_channels, 512, 1) + self.blocks = 
nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)]) + self.conv_out = nn.Conv2d(512, out_channels, 1) + + def forward(self, x): + x = self.conv_in(x) + for block in self.blocks: + x = block(x) + x = self.conv_out(x) + return x + + +class DiffusionFeatureExtractor3(nn.Module): + def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None): + super().__init__() + self.version = 3 + if vae is None: + vae = AutoencoderTiny.from_pretrained( + "madebyollin/taef1", torch_dtype=torch.bfloat16) + self.vae = vae + # image_encoder_path = "google/siglip-so400m-patch14-384" + image_encoder_path = "google/siglip2-so400m-patch16-512" + try: + self.image_processor = SiglipImageProcessor.from_pretrained( + image_encoder_path) + except EnvironmentError: + self.image_processor = SiglipImageProcessor() + self.vision_encoder = SiglipVisionModel.from_pretrained( + image_encoder_path, + ignore_mismatched_sizes=True + ).to(device, dtype=dtype) + + self.lpips_model = lpips_model = lpips.LPIPS(net='vgg') + self.lpips_model = lpips_model.to(device, dtype=torch.float32) + self.losses = {} + self.log_every = 100 + self.step = 0 + + def get_siglip_features(self, tensors_0_1): + dtype = torch.bfloat16 + device = self.vae.device + # resize to 384x384 + if 'height' in self.image_processor.size: + size = self.image_processor.size['height'] + else: + size = self.image_processor.crop_size['height'] + images = F.interpolate(tensors_0_1, size=(size, size), + mode='bicubic', align_corners=False) + + mean = torch.tensor(self.image_processor.image_mean).to( + device, dtype=dtype + ).detach() + std = torch.tensor(self.image_processor.image_std).to( + device, dtype=dtype + ).detach() + # tensors_0_1 = torch.clip((255. 
* tensors_0_1), 0, 255).round() / 255.0 + clip_image = ( + images - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + id_embeds = self.vision_encoder( + clip_image, + output_hidden_states=True, + ) + + last_hidden_state = id_embeds['last_hidden_state'] + return last_hidden_state + + def get_lpips_features(self, tensors_0_1): + device = self.vae.device + tensors_n1p1 = (tensors_0_1 * 2) - 1 + def get_lpips_features(img): # -1 to 1 + in0_input = self.lpips_model.scaling_layer(img) + outs0 = self.lpips_model.net.forward(in0_input) + + feats0 = {} + + feats_list = [] + for kk in range(self.lpips_model.L): + feats0[kk] = lpips.normalize_tensor(outs0[kk]) + feats_list.append(feats0[kk]) + + # 512 in + # vgg + # 0 torch.Size([1, 64, 512, 512]) + # 1 torch.Size([1, 128, 256, 256]) + # 2 torch.Size([1, 256, 128, 128]) + # 3 torch.Size([1, 512, 64, 64]) + # 4 torch.Size([1, 512, 32, 32]) + + return feats_list + + # do lpips + lpips_feat_list = [x for x in get_lpips_features( + tensors_n1p1.to(device, dtype=torch.float32))] + + return lpips_feat_list + + + def forward( + self, + noise, + noise_pred, + noisy_latents, + timesteps, + batch: DataLoaderBatchDTO, + scheduler: CustomFlowMatchEulerDiscreteScheduler, + # lpips_weight=1.0, + lpips_weight=10.0, + clip_weight=0.1, + pixel_weight=0.1, + model=None + ): + dtype = torch.bfloat16 + device = self.vae.device + + + if model is not None and hasattr(model, 'get_stepped_pred'): + stepped_latents = model.get_stepped_pred(noise_pred, noise) + else: + # stepped_latents = noise - noise_pred + # first we step the scheduler from current timestep to the very end for a full denoise + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + scheduler._step_index = None + 
scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = scheduler.sigmas[scheduler.step_index] + sigma_next = scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) + + latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype) + + scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0 + shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0 + latents = (latents / scaling_factor) + shift_factor + tensors_n1p1 = self.vae.decode(latents) # -1 to 1 + if hasattr(tensors_n1p1, 'sample'): + tensors_n1p1 = tensors_n1p1.sample + + pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1 + + lpips_feat_list_pred = self.get_lpips_features(pred_images.float()) + + total_loss = 0 + + with torch.no_grad(): + target_img = batch.tensor.to(device, dtype=dtype) + # go from -1 to 1 to 0 to 1 + target_img = (target_img + 1) / 2 + lpips_feat_list_target = self.get_lpips_features(target_img.float()) + if clip_weight > 0: + target_clip_output = self.get_siglip_features(target_img).detach() + if clip_weight > 0: + pred_clip_output = self.get_siglip_features(pred_images) + clip_loss = torch.nn.functional.mse_loss( + pred_clip_output.float(), target_clip_output.float() + ) * clip_weight + + if 'clip_loss' not in self.losses: + self.losses['clip_loss'] = clip_loss.item() + else: + self.losses['clip_loss'] += clip_loss.item() + + total_loss += clip_loss + + skip_lpips_layers = [] + + lpips_loss = 0 + for idx, lpips_feat in enumerate(lpips_feat_list_pred): + if idx in skip_lpips_layers: + continue + lpips_loss += torch.nn.functional.mse_loss( + lpips_feat.float(), lpips_feat_list_target[idx].float() + ) * lpips_weight + + if f'lpips_loss_{idx}' not in self.losses: + self.losses[f'lpips_loss_{idx}'] = 
lpips_loss.item() + else: + self.losses[f'lpips_loss_{idx}'] += lpips_loss.item() + + total_loss += lpips_loss + + # mse_loss = torch.nn.functional.mse_loss( + # stepped_latents.float(), batch.latents.float() + # ) * pixel_weight + + # if 'pixel_loss' not in self.losses: + # self.losses['pixel_loss'] = mse_loss.item() + # else: + # self.losses['pixel_loss'] += mse_loss.item() + + if self.step % self.log_every == 0 and self.step > 0: + print(f"DFE losses:") + for key in self.losses: + self.losses[key] /= self.log_every + # print in 2.000e-01 format + print(f" - {key}: {self.losses[key]:.3e}") + self.losses[key] = 0.0 + + # total_loss += mse_loss + self.step += 1 + + return total_loss + +class DiffusionFeatureExtractor4(nn.Module): + def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None): + super().__init__() + self.version = 4 + if vae is None: + raise ValueError("vae must be provided for DFE4") + self.vae = vae + # image_encoder_path = "google/siglip-so400m-patch14-384" + image_encoder_path = "google/siglip2-so400m-patch16-naflex" + from transformers import Siglip2ImageProcessor, Siglip2VisionModel + try: + self.image_processor = Siglip2ImageProcessor.from_pretrained( + image_encoder_path) + except EnvironmentError: + self.image_processor = Siglip2ImageProcessor() + + self.image_processor.max_num_patches = 1024 + + self.vision_encoder = Siglip2VisionModel.from_pretrained( + image_encoder_path, + ignore_mismatched_sizes=True + ).to(device, dtype=dtype) + + self.losses = {} + self.log_every = 100 + self.step = 0 + + def _target_hw(self, h, w, patch, max_patches, eps: float = 1e-5): + def _snap(x, s): + x = math.ceil((x * s) / patch) * patch + return max(patch, int(x)) + + lo, hi = eps / 10, 1.0 + while hi - lo >= eps: + mid = (lo + hi) / 2 + th, tw = _snap(h, mid), _snap(w, mid) + if (th // patch) * (tw // patch) <= max_patches: + lo = mid + else: + hi = mid + return _snap(h, lo), _snap(w, lo) + + + def tensors_to_siglip_like_features(self, 
batch: torch.Tensor): + """ + Args: + batch: (bs, 3, H, W) tensor already in the desired value range + (e.g. [-1, 1] or [0, 1]); no extra rescale / normalize here. + + Returns: + dict( + pixel_values – (bs, L, P) where L = n_h*n_w, P = 3*patch*patch + pixel_attention_mask– (L,) all-ones + spatial_shapes – (n_h, n_w) + ) + """ + if batch.ndim != 4: + raise ValueError("Expected (bs, 3, H, W) tensor") + + bs, c, H, W = batch.shape + proc = self.image_processor + patch = proc.patch_size + max_patches = proc.max_num_patches + + # One shared resize for the whole batch + tgt_h, tgt_w = self._target_hw(H, W, patch, max_patches) + batch = torch.nn.functional.interpolate( + batch, size=(tgt_h, tgt_w), mode="bilinear", align_corners=False + ) + + n_h, n_w = tgt_h // patch, tgt_w // patch + # flat_dim = c * patch * patch + num_p = n_h * n_w + + # unfold → (bs, flat_dim, num_p) → (bs, num_p, flat_dim) + patches = ( + torch.nn.functional.unfold(batch, kernel_size=patch, stride=patch) + .transpose(1, 2) + ) + + attn_mask = torch.ones(num_p, dtype=torch.long, device=batch.device) + spatial = torch.tensor((n_h, n_w), device=batch.device, dtype=torch.int32) + + # repeat attn_mask for each batch element + attn_mask = attn_mask.unsqueeze(0).repeat(bs, 1) + spatial = spatial.unsqueeze(0).repeat(bs, 1) + + return { + "pixel_values": patches, # (bs, num_patches, patch_dim) + "pixel_attention_mask": attn_mask, # (num_patches,) + "spatial_shapes": spatial + } + + def get_siglip_features(self, tensors_0_1): + dtype = torch.bfloat16 + device = self.vae.device + + tensors_0_1 = torch.clamp(tensors_0_1, 0.0, 1.0) + + mean = torch.tensor(self.image_processor.image_mean).to( + device, dtype=dtype + ).detach() + std = torch.tensor(self.image_processor.image_std).to( + device, dtype=dtype + ).detach() + # tensors_0_1 = torch.clip((255. 
* tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + + encoder_kwargs = self.tensors_to_siglip_like_features(clip_image) + id_embeds = self.vision_encoder( + pixel_values=encoder_kwargs['pixel_values'], + pixel_attention_mask=encoder_kwargs['pixel_attention_mask'], + spatial_shapes=encoder_kwargs['spatial_shapes'], + output_hidden_states=True, + ) + + image_embeds = id_embeds['hidden_states'][-2] # penultimate layer + # image_embeds = id_embeds['pooler_output'] + # image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + return image_embeds + + def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler): + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + scheduler._step_index = None + scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = scheduler.sigmas[scheduler.step_index] + sigma_next = scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) + return stepped_latents + + def forward( + self, + noise, + noise_pred, + noisy_latents, + timesteps, + batch: DataLoaderBatchDTO, + scheduler: CustomFlowMatchEulerDiscreteScheduler, + clip_weight=1.0, + mse_weight=0.0, + model=None + ): + dtype = torch.bfloat16 + device = self.vae.device + tensors = batch.tensor.to(device, dtype=dtype) + is_video = False + # stack time for video models on the batch dimension + if len(noise_pred.shape) == 5: + # B, C, T, H, W = images.shape + # only take first time + noise = noise[:, :, 0, :, :] + noise_pred = noise_pred[:, :, 0, :, :] + 
noisy_latents = noisy_latents[:, :, 0, :, :] + is_video = True + + if len(tensors.shape) == 5: + # batch is different + # (B, T, C, H, W) + # only take first time + tensors = tensors[:, 0, :, :, :] + + if model is not None and hasattr(model, 'get_stepped_pred'): + stepped_latents = model.get_stepped_pred(noise_pred, noise) + else: + stepped_latents = self.step_latents(noise, noise_pred, noisy_latents, timesteps, scheduler) + + latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype) + + scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0 + shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0 + latents = (latents / scaling_factor) + shift_factor + if is_video: + # if video, we need to unsqueeze the latents to match the vae input shape + latents = latents.unsqueeze(2) + tensors_n1p1 = self.vae.decode(latents) # -1 to 1 + if hasattr(tensors_n1p1, 'sample'): + tensors_n1p1 = tensors_n1p1.sample + + if is_video: + # if video, we need to squeeze the tensors to match the output shape + tensors_n1p1 = tensors_n1p1.squeeze(2) + + pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1 + + total_loss = 0 + + with torch.no_grad(): + target_img = tensors.to(device, dtype=dtype) + # go from -1 to 1 to 0 to 1 + target_img = (target_img + 1) / 2 + if clip_weight > 0: + target_clip_output = self.get_siglip_features(target_img).detach() + if clip_weight > 0: + pred_clip_output = self.get_siglip_features(pred_images) + clip_loss = torch.nn.functional.mse_loss( + pred_clip_output.float(), target_clip_output.float() + ) * clip_weight + + if 'clip_loss' not in self.losses: + self.losses['clip_loss'] = clip_loss.item() + else: + self.losses['clip_loss'] += clip_loss.item() + + total_loss += clip_loss + if mse_weight > 0: + mse_loss = torch.nn.functional.mse_loss( + pred_images.float(), target_img.float() + ) * mse_weight + + if 'mse_loss' not in self.losses: + self.losses['mse_loss'] = 
mse_loss.item() + else: + self.losses['mse_loss'] += mse_loss.item() + + total_loss += mse_loss + + if self.step % self.log_every == 0 and self.step > 0: + print(f"DFE losses:") + for key in self.losses: + self.losses[key] /= self.log_every + # print in 2.000e-01 format + print(f" - {key}: {self.losses[key]:.3e}") + self.losses[key] = 0.0 + + # total_loss += mse_loss + self.step += 1 + + return total_loss + +class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4): + def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None): + super().__init__(device=device, dtype=dtype, vae=vae) + self.version = 5 + + def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler, total_steps: int = 1000, eps: float = 1e-6): + bs = noise_pred.shape[0] + + # Chunk inputs per-sample (keeps existing structure) + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + noise_chunks = torch.chunk(noise, bs) + + stepped_chunks = [] + x0_pred_chunks = [] + + for idx in range(bs): + model_output = noise_pred_chunks[idx] # predicted noise (same shape as latent) + timestep = timestep_chunks[idx] # scalar tensor per sample (e.g., [t]) + sample = noisy_latent_chunks[idx].to(torch.float32) + noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device) + + # Initialize scheduler step index for this sample + scheduler._step_index = None + scheduler._init_step_index(timestep) + + # ---- Step +50 indices (or to the end) in sigma-space ---- + sigma = scheduler.sigmas[scheduler.step_index] + target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1) + sigma_next = scheduler.sigmas[target_idx] + + # One-step update along the model-predicted direction + stepped = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(stepped) + + # ---- Inverse-Gaussian recovery at the target timestep ---- + t_01 = 
(scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype) + original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01) + x0_pred_chunks.append(original_samples) + + # stepped_latents = torch.cat(stepped_chunks, dim=0) + predicted_images = torch.cat(x0_pred_chunks, dim=0) + # return stepped_latents, predicted_images + return predicted_images + + +class DiffusionFeatureExtractor6(nn.Module): + def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None): + super().__init__() + self.version = 6 + if vae is None: + raise ValueError("vae must be provided for DFE4") + self.vae = vae + # pretrained_model_name = "facebook/dinov3-vits16-pretrain-lvd1689m" + # pretrained_model_name = "facebook/dinov3-vitl16-pretrain-lvd1689m" + pretrained_model_name = "facebook/dinov3-vith16plus-pretrain-lvd1689m" + # pretrained_model_name = "facebook/dinov3-vit7b16-pretrain-lvd1689m" + self.processor = AutoImageProcessor.from_pretrained(pretrained_model_name) + self.model = AutoModel.from_pretrained( + pretrained_model_name, + device_map=device, + dtype=dtype, + ).to(device, dtype=dtype) + + self.losses = {} + self.log_every = 100 + self.step = 0 + + def prepare_inputs(self, tensor_0_1: torch.Tensor): + """ + tensor_0_1: (bs, 3, h, w), float, values in [0, 1] + returns: {"pixel_values": (bs, 3, H, W)} ready for the vision transformer + """ + + if tensor_0_1.ndim != 4 or tensor_0_1.shape[1] != 3: + raise ValueError(f"Expected (bs, 3, h, w), got {tuple(tensor_0_1.shape)}") + + x = tensor_0_1 + if not torch.is_floating_point(x): + x = x.float() + + # Resize + # if not divisible by 16 or total pixels > max_res*max_res, resize to fit within 16 patches + max_res = 512 + p = 16 + if (x.shape[-1] % p != 0) or (x.shape[-2] % p != 0) or (x.shape[-1] * x.shape[-2] > max_res * max_res): + target_h = x.shape[-2] + target_w = x.shape[-1] + if x.shape[-1] * target_h > max_res * max_res: + scale_factor = math.sqrt((max_res * max_res) / (target_w * target_h)) + target_h = 
int(target_h * scale_factor) + target_w = int(target_w * scale_factor) + target_h = (target_h // p) * p + target_w = (target_w // p) * p + x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False) + + # Rescale (HF processors usually assume uint8 0..255 inputs; your inputs are already 0..1) + if self.processor.do_rescale: + # If it looks like [0..1], skip to avoid double-scaling. + # If user accidentally passed 0..255 floats, this will fix it. + if x.detach().max().item() > 1.0 + 1e-6: + x = x * float(self.processor.rescale_factor or 1.0 / 255.0) + + # Normalize + if self.processor.do_normalize: + mean = torch.tensor(self.processor.image_mean, device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + std = torch.tensor(self.processor.image_std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + x = (x - mean) / std + + return {"pixel_values": x} + + def forward( + self, + noise, + noise_pred, + noisy_latents, + timesteps, + batch: DataLoaderBatchDTO, + scheduler: CustomFlowMatchEulerDiscreteScheduler, + model=None + ): + dtype = torch.bfloat16 + device = self.vae.device + tensors = batch.tensor.to(device, dtype=dtype) + is_video = False + # stack time for video models on the batch dimension + if len(noise_pred.shape) == 5: + # B, C, T, H, W = images.shape + # only take first time + noise = noise[:, :, 0, :, :] + noise_pred = noise_pred[:, :, 0, :, :] + noisy_latents = noisy_latents[:, :, 0, :, :] + is_video = True + + if len(tensors.shape) == 5: + # batch is different + # (B, T, C, H, W) + # only take first time + tensors = tensors[:, 0, :, :, :] + + with torch.no_grad(): + tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0 + # expand shape to match noise_pred + while len(tv.shape) < len(noise_pred.shape): + tv = tv.unsqueeze(-1) + # min 0.001 + tv = torch.clamp(tv, min=0.001) + + # step latent + x0 = noisy_latents - tv * noise_pred + + stepped_latents = x0 + + latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype) 
+ + scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0 + shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0 + latents = (latents / scaling_factor) + shift_factor + if is_video: + # if video, we need to unsqueeze the latents to match the vae input shape + latents = latents.unsqueeze(2) + tensors_n1p1 = self.vae.decode(latents) # -1 to 1 + if hasattr(tensors_n1p1, 'sample'): + tensors_n1p1 = tensors_n1p1.sample + + if is_video: + # if video, we need to squeeze the tensors to match the output shape + tensors_n1p1 = tensors_n1p1.squeeze(2) + + pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1 + + with torch.no_grad(): + target_img = tensors.to(device, dtype=dtype) + # go from -1 to 1 to 0 to 1 + target_img = (target_img + 1) / 2 + target_dino_input = self.prepare_inputs(target_img) + target_dino_output = self.model(**target_dino_input).pooler_output.detach() + # normalize + target_dino_output = (target_dino_output - target_dino_output.mean()) / (target_dino_output.std() + 1e-6) + pred_dino_input = self.prepare_inputs(pred_images) + pred_dino_output = self.model(**pred_dino_input).pooler_output + # normalize + pred_dino_output = (pred_dino_output - pred_dino_output.mean()) / (pred_dino_output.std() + 1e-6) + dino_loss = torch.nn.functional.mse_loss( + pred_dino_output.float(), target_dino_output.float() + ) + + if 'dinov3' not in self.losses: + self.losses['dinov3'] = dino_loss.item() + else: + self.losses['dinov3'] += dino_loss.item() + + with torch.no_grad(): + if self.step % self.log_every == 0 and self.step > 0: + print(f"DFE losses:") + for key in self.losses: + self.losses[key] /= self.log_every + # print in 2.000e-01 format + print(f" - {key}: {self.losses[key]:.3e}") + self.losses[key] = 0.0 + + # total_loss += mse_loss + self.step += 1 + + return dino_loss + +class ModelOutputWrapper: + def __init__(self, head, depth, normals, segmentation): + self.head = head + 
self.depth = depth + self.normals = normals + self.segmentation = segmentation + + +class DiffusionFeatureExtractor7(nn.Module): + def __init__( + self, + device=torch.device("cuda"), + dtype=torch.bfloat16, + vae=None, + sd=None, + partial_step: bool = False + ): + super().__init__() + + self.version = 7 + self.sd_ref = weakref.ref(sd) if sd is not None else None + from toolkit.models.tipsv2 import TIPSv2DPTModel + pretrained_model_name = "google/tipsv2-b14-dpt" + self.model = TIPSv2DPTModel.from_pretrained( + pretrained_model_name, + dtype=dtype, + ).to(device, dtype=dtype) + + self.losses = {} + self.log_every = 100 + self.step = 0 + self.do_partial_step = partial_step + + def get_pred(self, tensor_0_1: torch.Tensor): + """ + tensor_0_1: (bs, 3, h, w), float, values in [0, 1] + returns: {"pixel_values": (bs, 3, H, W)} ready for the vision transformer + """ + + if tensor_0_1.ndim != 4 or tensor_0_1.shape[1] != 3: + raise ValueError(f"Expected (bs, 3, h, w), got {tuple(tensor_0_1.shape)}") + + x = tensor_0_1.to(self.model.device, dtype=self.model.dtype) + + # Resize + # if not divisible by 16 or total pixels > max_res*max_res, resize to fit within 16 patches + max_res = 1024 + p = 14 + if (x.shape[-1] % p != 0) or (x.shape[-2] % p != 0) or (x.shape[-1] * x.shape[-2] > max_res * max_res): + target_h = x.shape[-2] + target_w = x.shape[-1] + if x.shape[-1] * target_h > max_res * max_res: + scale_factor = math.sqrt((max_res * max_res) / (target_w * target_h)) + target_h = int(target_h * scale_factor) + target_w = int(target_w * scale_factor) + target_h = (target_h // p) * p + target_w = (target_w // p) * p + x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False) + + # do inference. 
us standard dpy but also the head + pixel_values = x.to(self.model.device, dtype=self.model.dtype) + h, w = pixel_values.shape[2:] + dpt_inputs = self.model._extract_intermediate(pixel_values) + # head is a list of 4 + # each of the 4 is a tuple of (embeds, hidden_state) + # concat the hidden states from the 4 layers on dim 1 + head = torch.cat([h[1] for h in dpt_inputs], dim=1) + return ModelOutputWrapper( + head=head, + depth=self.model.depth_head(dpt_inputs, image_size=(h, w)), + normals=self.model.normals_head(dpt_inputs, image_size=(h, w)), + segmentation=self.model.segmentation_head(dpt_inputs, image_size=(h, w)), + ) + + def forward( + self, + noise, + noise_pred, + noisy_latents, + timesteps, + batch: DataLoaderBatchDTO, + scheduler: CustomFlowMatchEulerDiscreteScheduler, + model=None + ): + dtype = torch.bfloat16 + device = self.sd_ref().vae.device + tensors = batch.tensor.to(device, dtype=dtype) + is_video = False + # stack time for video models on the batch dimension + if len(noise_pred.shape) == 5: + # B, C, T, H, W = images.shape + # only take first time + noise = noise[:, :, 0, :, :] + noise_pred = noise_pred[:, :, 0, :, :] + noisy_latents = noisy_latents[:, :, 0, :, :] + is_video = True + + if len(tensors.shape) == 5: + # batch is different + # (B, T, C, H, W) + # only take first time + tensors = tensors[:, 0, :, :, :] + + with torch.no_grad(): + tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0 + # expand shape to match noise_pred + while len(tv.shape) < len(noise_pred.shape): + tv = tv.unsqueeze(-1) + + with torch.no_grad(): + target_0_1 = (tensors + 1) / 2 # 0 to 1 + + if not self.do_partial_step: + # step latent + x0 = noisy_latents - tv * noise_pred + stepped_latents = x0 + # min 0.001 + tv = torch.clamp(tv, min=0.001) + else: + # step is random 0.05 to 0.02 + step = torch.rand_like(tv) * 0.03 + 0.02 + next_step = tv - step + next_step = torch.clamp(next_step, min=0.0) + stepped_latents = noisy_latents + (next_step - tv) * 
noise_pred + + with torch.no_grad(): + # make a noisy target at next timestep + target_latents = batch.latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype) + # add noise + target_latents = (1.0 - next_step) * target_latents + next_step * noise + target_n1p1 = self.sd_ref().decode_latents(target_latents) + target_0_1 = (target_n1p1 + 1) / 2 # 0 to 1 + + latents = stepped_latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype) + + tensors_n1p1 = self.sd_ref().decode_latents(latents) + + pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1 + + device = self.model.device + dtype = self.model.dtype + + velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2) + + with torch.no_grad(): + target = self.get_pred(target_0_1) + + pred_images = pred_images.to(device, dtype=dtype) + pred = self.get_pred(pred_images) + + head_loss = torch.nn.functional.mse_loss( + pred.head.float(), target.head.float(), reduction='none' + ) * velocity_equiv_weight + head_loss = head_loss.mean() + + depth_loss = torch.nn.functional.l1_loss( + pred.depth.float(), target.depth.float(), reduction='none' + ) * velocity_equiv_weight + depth_loss = depth_loss.mean() + + normals_loss = torch.nn.functional.l1_loss( + pred.normals.float(), target.normals.float(), reduction='none' + ) * velocity_equiv_weight + normals_loss = normals_loss.mean() + + segmentation_loss = torch.nn.functional.l1_loss( + pred.segmentation.float(), target.segmentation.float(), reduction='none' + ) * velocity_equiv_weight + segmentation_loss = segmentation_loss.mean() + + total_loss = (head_loss + depth_loss + normals_loss + segmentation_loss) / 4.0 + + if self.do_partial_step: + total_loss = total_loss * 10.0 + + if 'total' not in self.losses: + self.losses['total'] = total_loss.item() + else: + self.losses['total'] += total_loss.item() + + if 'head' not in self.losses: + self.losses['head'] = head_loss.item() + else: + self.losses['head'] += head_loss.item() + + if 'depth' not in self.losses: + 
self.losses['depth'] = depth_loss.item() + else: + self.losses['depth'] += depth_loss.item() + + if 'normals' not in self.losses: + self.losses['normals'] = normals_loss.item() + else: + self.losses['normals'] += normals_loss.item() + + if 'segmentation' not in self.losses: + self.losses['segmentation'] = segmentation_loss.item() + else: + self.losses['segmentation'] += segmentation_loss.item() + + with torch.no_grad(): + if self.step % self.log_every == 0 and self.step > 0: + print(f"DFE losses:") + for key in self.losses: + self.losses[key] /= self.log_every + # print in 2.000e-01 format + print(f" - {key}: {self.losses[key]:.3e}") + self.losses[key] = 0.0 + + # total_loss += mse_loss + self.step += 1 + + return total_loss + +class DiffusionFeatureExtractor8(DiffusionFeatureExtractor7): + def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None, sd=None): + super().__init__(device=device, dtype=dtype, vae=vae, sd=sd, partial_step=True) + self.version = 8 + +class DiffusionFeatureExtractor9(nn.Module): + def __init__( + self, + device=torch.device("cuda"), + dtype=torch.bfloat16, + vae=None, + sd=None, + partial_step: bool = False + ): + super().__init__() + + self.version = 9 + self.sd_ref = weakref.ref(sd) if sd is not None else None + ckpt_path = huggingface_hub.hf_hub_download(repo_id="facebook/sapiens2-pretrain-1b", filename="sapiens2_1b_pretrain.safetensors") + self.model = Sapiens2(arch="sapiens2_1b", img_size=(1024, 768), patch_size=16).eval().cuda() # img_size is (H, W) + self.model.load_state_dict(load_file(ckpt_path)) + self.model.to(device, dtype=dtype) + + self.losses = {} + self.log_every = 100 + self.step = 0 + self.do_partial_step = partial_step + + def get_pred(self, tensor_0_1: torch.Tensor): + if tensor_0_1.ndim != 4 or tensor_0_1.shape[1] != 3: + raise ValueError(f"Expected (bs, 3, h, w), got {tuple(tensor_0_1.shape)}") + + x = tensor_0_1.to(self.model.device, dtype=self.model.dtype) + """Apply ImageNet normalization to a 
(B, C, H, W) RGB tensor in [0, 1].""" + mean = torch.as_tensor((0.485, 0.456, 0.406), dtype=x.dtype, device=x.device).view(1, 3, 1, 1) + std = torch.as_tensor((0.229, 0.224, 0.225), dtype=x.dtype, device=x.device).view(1, 3, 1, 1) + x = (x - mean) / std + + # Resize + # if not divisible by 16 or total pixels > max_res*max_res, resize to fit within 16 patches + max_res = 1024 + p = 16 + if (x.shape[-1] % p != 0) or (x.shape[-2] % p != 0) or (x.shape[-1] * x.shape[-2] > max_res * max_res): + target_h = x.shape[-2] + target_w = x.shape[-1] + if x.shape[-1] * target_h > max_res * max_res: + scale_factor = math.sqrt((max_res * max_res) / (target_w * target_h)) + target_h = int(target_h * scale_factor) + target_w = int(target_w * scale_factor) + target_h = (target_h // p) * p + target_w = (target_w // p) * p + x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False) + x = x.to(self.model.device, dtype=self.model.dtype) + features = self.model(x)[0] + return features + + def forward( + self, + noise, + noise_pred, + noisy_latents, + timesteps, + batch: DataLoaderBatchDTO, + scheduler: CustomFlowMatchEulerDiscreteScheduler, + model=None + ): + dtype = torch.bfloat16 + device = self.sd_ref().vae.device + tensors = batch.tensor.to(device, dtype=dtype) + is_video = False + # stack time for video models on the batch dimension + if len(noise_pred.shape) == 5: + # B, C, T, H, W = images.shape + # only take first time + noise = noise[:, :, 0, :, :] + noise_pred = noise_pred[:, :, 0, :, :] + noisy_latents = noisy_latents[:, :, 0, :, :] + is_video = True + + if len(tensors.shape) == 5: + # batch is different + # (B, T, C, H, W) + # only take first time + tensors = tensors[:, 0, :, :, :] + + with torch.no_grad(): + tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0 + # expand shape to match noise_pred + while len(tv.shape) < len(noise_pred.shape): + tv = tv.unsqueeze(-1) + + with torch.no_grad(): + target_0_1 = (tensors + 1) / 2 # 0 to 1 
class DiffusionFeatureExtractor10(nn.Module):
    """Perceptual loss (version 10) built on LPIPS VGG feature maps.

    ``forward`` steps the noisy latents using the model prediction, decodes
    them through ``sd.decode_latents`` and compares the LPIPS VGG features of
    the decoded prediction against those of a target image with a
    timestep-weighted MSE.
    """

    def __init__(
        self,
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
        vae=None,
        sd=None,
        partial_step: bool = False
    ):
        """Build the LPIPS backbone and the running-loss bookkeeping.

        Args:
            device: device the LPIPS model is moved to.
            dtype: accepted but not read anywhere in this class; the LPIPS
                model is always kept in float32.
            vae: accepted but not read anywhere in this class.
            sd: model wrapper, held through a weak reference. ``forward``
                dereferences it unconditionally, so it fails if ``sd`` is None.
            partial_step: when True ``forward`` takes a random partial step
                towards x0 instead of a full step.
        """
        super().__init__()

        self.version = 10
        # weak reference so this module does not keep the model wrapper alive
        self.sd_ref = weakref.ref(sd) if sd is not None else None
        self.lpips_model = lpips.LPIPS(net='vgg')
        self.lpips_model = self.lpips_model.to(device, dtype=torch.float32)

        # running sums of logged losses, reset every `log_every` steps
        self.losses = {}
        self.log_every = 100
        self.step = 0
        self.do_partial_step = partial_step

    def _vgg_slices(self, x):
        """Return the outputs of the five LPIPS VGG slices applied in sequence to ``x``."""
        # run the lpips vgg backbone slice-by-slice so we can gradient
        # checkpoint each slice. checkpointing activates whenever grads are
        # enabled, so it does not require the module to be in train mode.
        net = self.lpips_model.net
        slices = [net.slice1, net.slice2, net.slice3, net.slice4, net.slice5]
        outs = []
        h = x
        for s in slices:
            if torch.is_grad_enabled():
                h = ckpt.checkpoint(s, h, use_reentrant=False)
            else:
                h = s(h)
            outs.append(h)
        return outs

    def get_lpips_features(self, tensors_0_1):
        """Return a list of normalized LPIPS feature maps for images given in the 0..1 range.

        The input is rescaled to -1..1, moved to the LPIPS device in float32,
        passed through the LPIPS scaling layer and the VGG slices, and the
        first ``self.lpips_model.L`` slice outputs are normalized with
        ``lpips.normalize_tensor``.
        """
        device = self.lpips_model.scaling_layer.shift.device
        tensors_n1p1 = (tensors_0_1 * 2) - 1
        # inner helper (shadows the method name locally) working on -1..1 images
        def get_lpips_features(img):  # -1 to 1
            in0_input = self.lpips_model.scaling_layer(img)
            outs0 = self._vgg_slices(in0_input)

            feats_list = []
            for kk in range(self.lpips_model.L):
                feats_list.append(lpips.normalize_tensor(outs0[kk]))

            return feats_list

        lpips_feat_list = [x for x in get_lpips_features(
            tensors_n1p1.to(device, dtype=torch.float32))]

        return lpips_feat_list

    def forward(
        self,
        noise,
        noise_pred,
        noisy_latents,
        timesteps,
        batch: DataLoaderBatchDTO,
        scheduler: CustomFlowMatchEulerDiscreteScheduler,
        model=None
    ):
        """Compute the timestep-weighted LPIPS feature loss for one batch.

        Args:
            noise: noise tensor; only used in partial-step mode to build the
                noisy target.
            noise_pred: model prediction used to step the latents.
            noisy_latents: latents the prediction was made on.
            timesteps: timesteps, divided by 1000 here to form ``tv``.
            batch: provides ``tensor`` (target images, treated as -1..1) and,
                in partial-step mode, ``latents``.
            scheduler: not read in this method.
            model: not read in this method.

        Returns:
            The loss tensor (sum over LPIPS layers of the weighted MSE means).

        Side effects: accumulates the loss value in ``self.losses``, prints
        and resets the accumulated values when ``self.step`` is a positive
        multiple of ``self.log_every``, and increments ``self.step``.
        """
        dtype = torch.bfloat16
        device = self.sd_ref().vae.device
        tensors = batch.tensor.to(device, dtype=dtype)
        # NOTE(review): is_video is assigned below but never read in this method
        is_video = False
        # stack time for video models on the batch dimension
        if len(noise_pred.shape) == 5:
            # B, C, T, H, W = images.shape
            # only take first time
            noise = noise[:, :, 0, :, :]
            noise_pred = noise_pred[:, :, 0, :, :]
            noisy_latents = noisy_latents[:, :, 0, :, :]
            is_video = True

        if len(tensors.shape) == 5:
            # batch is different
            # (B, T, C, H, W)
            # only take first time
            tensors = tensors[:, 0, :, :, :]

        with torch.no_grad():
            tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0
            # expand shape to match noise_pred
            while len(tv.shape) < len(noise_pred.shape):
                tv = tv.unsqueeze(-1)

        with torch.no_grad():
            target_0_1 = (tensors + 1) / 2  # 0 to 1

        if not self.do_partial_step:
            # step latent
            x0 = noisy_latents - tv * noise_pred
            stepped_latents = x0
            # min 0.001
            # clamped only after the step above, so the step itself uses the raw tv
            tv = torch.clamp(tv, min=0.001)
        else:
            # step is random 0.1 to 0.25
            step = torch.rand_like(tv) * 0.15 + 0.1
            next_step = tv - step
            next_step = torch.clamp(next_step, min=0.0)
            stepped_latents = noisy_latents + (next_step - tv) * noise_pred

            with torch.no_grad():
                # make a noisy target at next timestep
                target_latents = batch.latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype)
                # add noise
                target_latents = (1.0 - next_step) * target_latents + next_step * noise
                target_n1p1 = self.sd_ref().decode_latents(target_latents)
                # replaces the image-based target computed above
                target_0_1 = (target_n1p1 + 1) / 2  # 0 to 1

        latents = stepped_latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype)

        # decoded with gradients enabled so the loss can reach noise_pred
        tensors_n1p1 = self.sd_ref().decode_latents(latents)

        pred_images = (tensors_n1p1 + 1) / 2  # 0 to 1

        with torch.no_grad():
            target_feats = self.get_lpips_features(target_0_1.float())

        pred_feats = self.get_lpips_features(pred_images.float())

        # per-sample weight 1 / tv^2, with tv floored at 0.1 to bound the weight
        velocity_equiv_weight = (1.0 / torch.clamp(tv, min=0.1) ** 2)

        loss_perceptual = 0
        for idx, pred_feat in enumerate(pred_feats):
            perceptual_loss = torch.nn.functional.mse_loss(
                pred_feat.float(), target_feats[idx].float(), reduction="none"
            )
            # mean over channels/spatial per sample, keep batch dim to weight by timestep
            perceptual_loss = perceptual_loss.mean(dim=[1, 2, 3], keepdim=True)
            loss_perceptual = loss_perceptual + (perceptual_loss * velocity_equiv_weight).mean()

        if self.do_partial_step:
            loss_perceptual = loss_perceptual * 10.0

        if 'loss' not in self.losses:
            self.losses['loss'] = loss_perceptual.item()
        else:
            self.losses['loss'] += loss_perceptual.item()
        with torch.no_grad():
            # NOTE(review): the first report covers steps 0..log_every inclusive
            # (log_every + 1 values) but divides by log_every - confirm intended
            if self.step % self.log_every == 0 and self.step > 0:
                print(f"DFE losses:")
                for key in self.losses:
                    self.losses[key] /= self.log_every
                    # print in 2.000e-01 format
                    print(f" - {key}: {self.losses[key]:.3e}")
                    self.losses[key] = 0.0

        # total_loss += mse_loss
        self.step += 1

        return loss_perceptual
def new_device_to(self: 'FluxTransformer2DModel', *args, **kwargs):
    """Replacement for ``transformer.to`` on a GPU-split Flux transformer.

    Any device found in the call (``device=`` keyword, or a positional ``str``
    / ``torch.device``) is applied only to the non-block layers. Every
    transformer block is instead moved to its own ``_split_device``. All
    remaining arguments (dtype, non_blocking, ...) are forwarded unchanged to
    every ``.to()`` call.

    A positional device takes precedence over the ``device=`` keyword; when
    several positional devices are given the last one is used.

    Returns:
        ``self``, matching ``torch.nn.Module.to``.
    """
    # The keyword device is consumed here so it is not forwarded to the
    # blocks, which must stay on their assigned split device.
    device = kwargs.pop('device', None)

    # Separate positional devices from the arguments to pass through. A new
    # list is built rather than deleting from ``args`` while iterating over
    # it, which would skip the element following each removed device.
    passthrough_args = []
    for arg in args:
        if isinstance(arg, (str, torch.device)):
            device = arg
        else:
            passthrough_args.append(arg)

    # layers that run before the blocks
    for name in ('pos_embed', 'time_text_embed', 'context_embedder', 'x_embedder'):
        setattr(self, name, getattr(self, name).to(device, *passthrough_args, **kwargs))

    # blocks keep their split placement; only the other arguments apply
    for block in self.transformer_blocks:
        block.to(block._split_device, *passthrough_args, **kwargs)
    for block in self.single_transformer_blocks:
        block.to(block._split_device, *passthrough_args, **kwargs)

    # layers that run after the blocks
    for name in ('norm_out', 'proj_out'):
        setattr(self, name, getattr(self, name).to(device, *passthrough_args, **kwargs))

    return self
def add_model_gpu_splitter_to_flux(
    transformer: 'FluxTransformer2DModel',
    # ~ 5 billion for all other params
    other_module_params: Optional[int] = 5e9,
    # since they are not trainable, multiply by smaller number
    other_module_param_count_scale: Optional[float] = 0.3
):
    """Spread the Flux transformer blocks across all visible CUDA devices.

    Each double and single block gets a ``_split_device`` and has its
    ``forward`` wrapped so inputs are moved to that device first. Blocks are
    assigned in order; once the parameters placed on the current GPU exceed an
    even share of the total, assignment advances to the next GPU (staying on
    the last one when there are no more). The output of the final single
    block is routed back to ``cuda:0`` and ``transformer.to`` is replaced with
    ``new_device_to`` so later moves keep the split.

    Args:
        transformer: the Flux transformer to patch in place.
        other_module_params: parameter count budgeted on GPU 0 for everything
            that is not a transformer block.
        other_module_param_count_scale: factor applied to
            ``other_module_params`` before budgeting.

    Raises:
        RuntimeError: if no CUDA device is visible (previously this surfaced
            as a ZeroDivisionError).
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        raise RuntimeError(
            "add_model_gpu_splitter_to_flux requires at least one CUDA device, found none."
        )
    last_gpu_idx = num_gpus - 1

    other_module_params *= other_module_param_count_scale

    total_params = sum(p.numel() for p in transformer.parameters()) + other_module_params
    params_per_gpu = total_params / num_gpus

    current_gpu_idx = 0
    # text encoders, vae, and some non block layers will all be on gpu 0
    current_gpu_params = other_module_params

    # double blocks first, then single blocks, sharing one running budget
    block_groups = (
        (transformer.transformer_blocks, split_gpu_double_block_forward),
        (transformer.single_transformer_blocks, split_gpu_single_block_forward),
    )
    for blocks, split_forward in block_groups:
        for block in blocks:
            # keep the original forward so the wrapper can delegate to it
            block._pre_gpu_split_forward = block.forward
            block.forward = partial(split_forward, block)
            block._split_device = torch.device(f"cuda:{current_gpu_idx}")
            # add the params to the current gpu
            current_gpu_params += sum(p.numel() for p in block.parameters())
            # once this gpu holds more than its share, move to the next one
            if current_gpu_params > params_per_gpu:
                current_gpu_idx = min(current_gpu_idx + 1, last_gpu_idx)
                current_gpu_params = 0

    # add output device to last layer
    transformer.single_transformer_blocks[-1]._split_output_device = torch.device("cuda:0")

    transformer._pre_gpu_split_to = transformer.to
    transformer.to = partial(new_device_to, transformer)
torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + from sageattention import sageattn + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = 
torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = sageattn(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states \ No newline at end of file diff --git a/ai-toolkit/toolkit/models/i2v_adapter.py b/ai-toolkit/toolkit/models/i2v_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..27bc7238c64ab490635517233a884bd740f5332f --- /dev/null +++ b/ai-toolkit/toolkit/models/i2v_adapter.py @@ -0,0 +1,586 @@ +from functools import partial +import inspect +import weakref +import torch +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.lora_special import LoRASpecialNetwork +from diffusers import WanTransformer3DModel +from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection +from diffusers.models.attention_processor import Attention +from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding +from toolkit.util.shuffle import shuffle_tensor_along_axis +import 
torch.nn.functional as F + +if TYPE_CHECKING: + from toolkit.models.base_model import BaseModel + from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig + from toolkit.custom_adapter import CustomAdapter + + +class FrameEmbedder(torch.nn.Module): + def __init__( + self, + adapter: 'I2VAdapter', + orig_layer: torch.nn.Conv3d, + in_channels=20, # wan is 16 normally, and 36 with i2v so 20 new channels + ): + super().__init__() + # goes through a conv patch embedding first and is then flattened + # hidden_states = self.patch_embedding(hidden_states) + # hidden_states = hidden_states.flatten(2).transpose(1, 2) + + inner_dim = orig_layer.out_channels + patch_size = adapter.sd_ref().model.config.patch_size + + self.patch_embedding = torch.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer) + + @classmethod + def from_model( + cls, + model: WanTransformer3DModel, + adapter: 'I2VAdapter', + ): + if model.__class__.__name__ == 'WanTransformer3DModel': + new_channels = 20 # wan is 16 normally, and 36 with i2v so 20 new channels + + orig_patch_embedding: torch.nn.Conv3d = model.patch_embedding + img_embedder = cls( + adapter, + orig_layer=orig_patch_embedding, + in_channels=new_channels, + ) + + # hijack the forward method + orig_patch_embedding._orig_i2v_adapter_forward = orig_patch_embedding.forward + orig_patch_embedding.forward = img_embedder.forward + + # update the config of the transformer, only needed when merged in + # model.config.in_channels = model.config.in_channels + new_channels + # model.config["in_channels"] = model.config.in_channels + new_channels + + return img_embedder + else: + raise ValueError("Model not supported") + + @property + def is_active(self): + return self.adapter_ref().is_active + + def forward(self, x): + if not self.is_active: + # make sure lora is not active + if 
class AttentionHog(torch.nn.Module):
    """Holds the extra image key/value projections added to one attention layer.

    The projections (and optional q/k norms) are owned by this module so they
    can be trained and saved, and are also stashed on the attention layer
    under underscore-prefixed names. The layer's ``forward`` is replaced with
    ``deactivatable_forward``, which reads ``is_active`` on this module to
    decide whether the projections are attached for a given call.
    """

    def __init__(
        self,
        added_kv_proj_dim: int,
        adapter: 'I2VAdapter',
        attn_layer: 'Attention',
        model: 'WanTransformer3DModel',
    ):
        """Create the projections and hook them onto ``attn_layer``.

        Args:
            added_kv_proj_dim: input width of the added k/v projections.
            adapter: owning adapter; its ``is_active`` drives activation.
            attn_layer: attention layer to extend; patched in place.
            model: transformer whose ``config.qk_norm`` selects the norm type.

        Raises:
            ValueError: if ``attn_layer`` has a q norm and ``qk_norm`` is not
                a supported value.
        """
        super().__init__()

        # To prevent circular import.
        from diffusers.models.normalization import FP32LayerNorm, RMSNorm

        self.added_kv_proj_dim = added_kv_proj_dim
        # weak references avoid reference cycles with the patched modules
        self.attn_layer_ref: weakref.ref = weakref.ref(attn_layer)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.model_ref: weakref.ref = weakref.ref(model)

        qk_norm = model.config.qk_norm

        # layers; weights are scaled down so the new path starts near zero
        self.add_k_proj = torch.nn.Linear(
            added_kv_proj_dim,
            attn_layer.inner_kv_dim,
            bias=attn_layer.added_proj_bias
        )
        self.add_k_proj.weight.data = self.add_k_proj.weight.data * 0.001
        self.add_v_proj = torch.nn.Linear(
            added_kv_proj_dim,
            attn_layer.inner_kv_dim,
            bias=attn_layer.added_proj_bias
        )
        self.add_v_proj.weight.data = self.add_v_proj.weight.data * 0.001

        # do qk norm. It isnt stored in the class, but we can infer it from the attn layer
        self.norm_added_q = None
        self.norm_added_k = None

        if attn_layer.norm_q is not None:
            eps: float = 1e-5
            if qk_norm == "layer_norm":
                self.norm_added_q = torch.nn.LayerNorm(
                    attn_layer.norm_q.normalized_shape, eps=eps, elementwise_affine=attn_layer.norm_q.elementwise_affine)
                self.norm_added_k = torch.nn.LayerNorm(
                    attn_layer.norm_k.normalized_shape, eps=eps, elementwise_affine=attn_layer.norm_k.elementwise_affine)
            elif qk_norm == "fp32_layer_norm":
                self.norm_added_q = FP32LayerNorm(
                    attn_layer.norm_q.normalized_shape, elementwise_affine=False, bias=False, eps=eps)
                self.norm_added_k = FP32LayerNorm(
                    attn_layer.norm_k.normalized_shape, elementwise_affine=False, bias=False, eps=eps)
            elif qk_norm in ("rms_norm", "rms_norm_across_heads"):
                # Wanx applies qk norm across all heads; both variants are
                # built the same way from the existing norm's dim
                self.norm_added_q = RMSNorm(attn_layer.norm_q.dim, eps=eps)
                self.norm_added_k = RMSNorm(attn_layer.norm_k.dim, eps=eps)
            else:
                raise ValueError(
                    f"unknown qk_norm: {qk_norm}. Should be one of "
                    f"`None,'layer_norm','fp32_layer_norm','rms_norm','rms_norm_across_heads'`"
                )

        # add these to the attn later in a way they can be deactivated
        attn_layer._add_k_proj = self.add_k_proj
        attn_layer._add_v_proj = self.add_v_proj
        attn_layer._norm_added_q = self.norm_added_q
        attn_layer._norm_added_k = self.norm_added_k

        # make it deactivateable
        attn_layer._attn_hog_ref = weakref.ref(self)
        attn_layer._orig_forward = attn_layer.forward
        attn_layer.forward = partial(deactivatable_forward, attn_layer)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped attention layer while the adapter is inactive.

        Raises:
            RuntimeError: if the wrapped attention layer no longer exists.
            NotImplementedError: if the adapter is active (not implemented).
        """
        if not self.adapter_ref().is_active:
            # the layer is only weakly referenced, so resolve it on each call
            attn_layer = self.attn_layer_ref()
            if attn_layer is None:
                raise RuntimeError("Attention layer for this AttentionHog no longer exists")
            return attn_layer(*args, **kwargs)

        # TODO implement this
        raise NotImplementedError("Attention hog not implemented")

    @property
    def is_active(self):
        """Whether the owning adapter is active.

        A property (as on ``FrameEmbedder`` and ``I2VAdapter``) because
        ``deactivatable_forward`` reads it as an attribute; as a plain method
        the bound method object was always truthy.
        """
        return self.adapter_ref().is_active
+ + if adapter.adapter_ref().is_sampling: + if not hasattr(self, '_do_unconditional'): + # set it to true so we alternate to false immediatly + self._do_unconditional = True + + # alternate it + self._do_unconditional = not self._do_unconditional + if self._do_unconditional: + # slightly reduce strength of conditional for the unconditional + # encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5 + # shuffle the embedding tokens so we still have all the information, but it is scrambled + # this will prevent things like color from being cfg overweights, but still sharpen content. + + encoder_hidden_states_image = shuffle_tensor_along_axis( + adapter.adapter_ref().conditional_embeds, + axis=1 + ) + # encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds + else: + # use the conditional + encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds + else: + # doing a normal training run, always use conditional embeds + encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds + + # add the first frame conditioning + if adapter.frame_embedder is not None: + with torch.no_grad(): + # add the first frame conditioning + conditioning_frame = adapter.adapter_ref().cached_control_image_0_1 + if conditioning_frame is None: + raise ValueError("No conditioning frame found") + + # make it -1 to 1 + conditioning_frame = (conditioning_frame * 2) - 1 + conditioning_frame = conditioning_frame.to( + hidden_states.device, dtype=hidden_states.dtype + ) + + # if doing a full denoise, the latent input may be full channels here, only get first 16 + if hidden_states.shape[1] > 16: + hidden_states = hidden_states[:, :16, :, :, :] + + + hidden_states = add_first_frame_conditioning( + latent_model_input=hidden_states, + first_frame=conditioning_frame, + vae=adapter.adapter_ref().sd_ref().vae, + ) + else: + # not active deactivate the condition embedder + self.condition_embedder.image_embedder = None + + return 
self._orig_i2v_adapter_forward( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=encoder_hidden_states_image, + return_dict=return_dict, + attention_kwargs=attention_kwargs, + ) + + +class I2VAdapter(torch.nn.Module): + def __init__( + self, + adapter: 'CustomAdapter', + sd: 'BaseModel', + config: 'AdapterConfig', + train_config: 'TrainConfig', + image_processor: Union[SiglipImageProcessor, CLIPImageProcessor], + vision_encoder: Union[SiglipVisionModel, CLIPVisionModelWithProjection], + ): + super().__init__() + # avoid circular import + from toolkit.models.wan21.wan_attn import WanAttnProcessor2_0 + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref = weakref.ref(sd) + self.model_config: ModelConfig = sd.model_config + self.network_config = config.lora_config + self.train_config = train_config + self.config = config + self.device_torch = sd.device_torch + self.control_lora = None + self.image_processor_ref: weakref.ref = weakref.ref(image_processor) + self.vision_encoder_ref: weakref.ref = weakref.ref(vision_encoder) + + ve_img_size = vision_encoder.config.image_size + ve_patch_size = vision_encoder.config.patch_size + num_patches = (ve_img_size // ve_patch_size) ** 2 + num_vision_tokens = num_patches + + # siglip does not have a class token + if not vision_encoder.__class__.__name__.lower().startswith("siglip"): + num_vision_tokens = num_patches + 1 + + model_class = sd.model.__class__.__name__ + + if self.network_config is not None: + + network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs + if hasattr(sd, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = sd.target_lora_modules + + if 'ignore_if_contains' not in network_kwargs: + network_kwargs['ignore_if_contains'] = [] + + network_kwargs['ignore_if_contains'] += [ + 'add_k_proj', + 'add_v_proj', + 'norm_added_q', + 'norm_added_k', + ] + if 
model_class == 'WanTransformer3DModel': + # always ignore patch_embedding + network_kwargs['ignore_if_contains'].append('patch_embedding') + + self.control_lora = LoRASpecialNetwork( + text_encoder=sd.text_encoder, + unet=sd.unet, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=False, + is_lorm=False, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=sd.is_transformer, + base_model=sd, + **network_kwargs + ) + self.control_lora.force_to(self.device_torch, dtype=torch.float32) + self.control_lora._update_torch_multiplier() + self.control_lora.apply_to( + sd.text_encoder, + sd.unet, + self.train_config.train_text_encoder, + self.train_config.train_unet + ) + self.control_lora.can_merge_in = False + self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet) + if self.train_config.gradient_checkpointing: + self.control_lora.enable_gradient_checkpointing() + + self.frame_embedder: FrameEmbedder = None + if self.config.i2v_do_start_frame: + self.frame_embedder = FrameEmbedder.from_model( + sd.unet, + self + ) + self.frame_embedder.to(self.device_torch) + + # hijack the blocks so we can inject our vision 
encoder + attn_hog_list = [] + if model_class == 'WanTransformer3DModel': + added_kv_proj_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim + # update the model so it can accept the new input + # wan has i2v with clip-h for i2v, additional k v attn that directly takes + # in the penultimate_hidden_states from the vision encoder + # the kv is on blocks[0].attn2 + sd.model.config.added_kv_proj_dim = added_kv_proj_dim + sd.model.config['added_kv_proj_dim'] = added_kv_proj_dim + + transformer: WanTransformer3DModel = sd.model + for block in transformer.blocks: + block.attn2.added_kv_proj_dim = added_kv_proj_dim + attn_module = AttentionHog( + added_kv_proj_dim, + self, + block.attn2, + transformer + ) + # set the attn function to ours that handles custom number of vision tokens + block.attn2.set_processor(WanAttnProcessor2_0(num_vision_tokens)) + + attn_hog_list.append(attn_module) + else: + raise ValueError(f"Model {model_class} not supported") + + self.attn_hog_list = torch.nn.ModuleList(attn_hog_list) + self.attn_hog_list.to(self.device_torch) + + inner_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim + image_embed_dim = vision_encoder.config.hidden_size + self.image_embedder = WanImageEmbedding(image_embed_dim, inner_dim) + + # override the forward method + if model_class == 'WanTransformer3DModel': + self.sd_ref().model._orig_i2v_adapter_forward = self.sd_ref().model.forward + self.sd_ref().model.forward = partial( + new_wan_forward, + self.sd_ref().model + ) + + # add the wan image embedder + self.sd_ref().model.condition_embedder._image_embedder = self.image_embedder + self.sd_ref().model.condition_embedder._image_embedder.to(self.device_torch) + + self.sd_ref().model._i2v_adapter_ref = weakref.ref(self) + + def get_params(self): + if self.control_lora is not None: + config = { + 'text_encoder_lr': self.train_config.lr, + 'unet_lr': self.train_config.lr, + } + sig = 
inspect.signature(self.control_lora.prepare_optimizer_params) + if 'default_lr' in sig.parameters: + config['default_lr'] = self.train_config.lr + if 'learning_rate' in sig.parameters: + config['learning_rate'] = self.train_config.lr + params_net = self.control_lora.prepare_optimizer_params( + **config + ) + + # we want only tensors here + params = [] + for p in params_net: + if isinstance(p, dict): + params += p["params"] + elif isinstance(p, torch.Tensor): + params.append(p) + elif isinstance(p, list): + params += p + else: + params = [] + + if self.frame_embedder is not None: + # make sure the embedder is float32 + self.frame_embedder.to(torch.float32) + params += list(self.frame_embedder.parameters()) + + # add the attn hogs + for attn_hog in self.attn_hog_list: + params += list(attn_hog.parameters()) + + # add the image embedder + if self.image_embedder is not None: + params += list(self.image_embedder.parameters()) + return params + + def load_weights(self, state_dict, strict=True): + lora_sd = {} + attn_hog_sd = {} + frame_embedder_sd = {} + image_embedder_sd = {} + + for key, value in state_dict.items(): + if "frame_embedder" in key: + new_key = key.replace("frame_embedder.", "") + frame_embedder_sd[new_key] = value + elif "attn_hog" in key: + new_key = key.replace("attn_hog.", "") + attn_hog_sd[new_key] = value + elif "image_embedder" in key: + new_key = key.replace("image_embedder.", "") + image_embedder_sd[new_key] = value + else: + lora_sd[key] = value + + # todo process state dict before loading + if self.control_lora is not None: + self.control_lora.load_weights(lora_sd) + if self.frame_embedder is not None: + self.frame_embedder.load_state_dict( + frame_embedder_sd, strict=False) + self.attn_hog_list.load_state_dict( + attn_hog_sd, strict=False) + self.image_embedder.load_state_dict( + image_embedder_sd, strict=False) + + def get_state_dict(self): + if self.control_lora is not None: + lora_sd = self.control_lora.get_state_dict(dtype=torch.float32) + 
else: + lora_sd = {} + + if self.frame_embedder is not None: + frame_embedder_sd = self.frame_embedder.state_dict() + for key, value in frame_embedder_sd.items(): + lora_sd[f"frame_embedder.{key}"] = value + + # add the attn hogs + attn_hog_sd = self.attn_hog_list.state_dict() + for key, value in attn_hog_sd.items(): + lora_sd[f"attn_hog.{key}"] = value + + # add the image embedder + image_embedder_sd = self.image_embedder.state_dict() + for key, value in image_embedder_sd.items(): + lora_sd[f"image_embedder.{key}"] = value + + return lora_sd + + def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO): + # todo handle start frame + return latents + + def edit_batch_processed(self, batch: DataLoaderBatchDTO): + with torch.no_grad(): + # we will alway get a clip image frame, if one is not passed, use image + # or if video, pull from the first frame + # edit the batch to pull the first frame out of a video if we have it + # videos come in (bs, num_frames, channels, height, width) + tensor = batch.tensor + if batch.clip_image_tensor is None: + if len(tensor.shape) == 5: + # we have a video + first_frames = tensor[:, 0, :, :, :].clone() + else: + # we have a single image + first_frames = tensor.clone() + + # it is -1 to 1, change it to 0 to 1 + first_frames = (first_frames + 1) / 2 + + # clip image tensors are preprocessed. 
class MLP(nn.Module):
    """Pre-norm feed-forward block: LayerNorm -> Linear -> GELU -> Linear -> Dropout.

    When ``use_residual`` is true the block input is added back to the output,
    which requires ``in_dim == out_dim`` (enforced with an assertion).
    """

    def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
        super().__init__()
        # a skip connection only makes sense when the shapes line up
        assert not use_residual or in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        """Run the block on ``x`` and return the (optionally residual) result."""
        hidden = self.act_fn(self.fc1(self.layernorm(x)))
        out = self.dropout(self.fc2(hidden))
        return out + x if self.use_residual else out
class InstantLoRAMidModule(torch.nn.Module):
    """Stand-in forwards for one LoRA module's down/up projections.

    The weights are not stored here. On every call they are read from
    ``instant_lora_module.img_embeds[index]``, a 2-D tensor with one row of
    flattened weights per sample: the first ``prod(down_shape)`` columns are the
    down weight and the last ``prod(up_shape)`` columns are the up weight.

    Both the LoRA module and the owning InstantLoRAModule are held through weak
    references, so this object does not keep either alive.
    """

    def __init__(
            self,
            index: int,
            lora_module: 'LoRAModule',
            instant_lora_module: 'InstantLoRAModule',
            up_shape: list = None,
            down_shape: list = None,
    ):
        super(InstantLoRAMidModule, self).__init__()
        self.up_shape = up_shape
        self.down_shape = down_shape
        self.index = index
        self.lora_module_ref = weakref.ref(lora_module)
        self.instant_lora_module_ref = weakref.ref(instant_lora_module)

        # last embedding slice fetched by down_forward / up_forward
        self.embed = None

    def _apply_per_sample(self, x, flat_weights, weight_shape, conv_kwargs):
        """Apply a different weight to every sample of ``x`` and re-stack the results.

        Args:
            x: input whose first dimension is the batch.
            flat_weights: 2-D tensor, one flattened weight per row.
            weight_shape: shape each row is viewed as; 4 dims selects conv2d,
                anything else selects a matmul with the transposed weight.
            conv_kwargs: keyword arguments for ``conv2d`` (unused for matmul).
        """
        batch_size = x.shape[0]

        # unconditional: the batch holds two passes that share one set of
        # weights, so repeat the weights to line up with it
        if flat_weights.shape[0] * 2 == batch_size:
            flat_weights = torch.cat([flat_weights] * 2, dim=0)

        weight_chunks = torch.chunk(flat_weights, batch_size, dim=0)
        x_chunks = torch.chunk(x, batch_size, dim=0)

        outputs = []
        for i in range(batch_size):
            weight = weight_chunks[i].view(weight_shape)
            if len(weight.shape) == 4:
                outputs.append(nn.functional.conv2d(x_chunks[i], weight, **conv_kwargs))
            else:
                # plain linear layer without bias
                outputs.append(x_chunks[i] @ weight.T)
        return torch.cat(outputs, dim=0)

    def _fetch_embed(self, x):
        """Refresh ``self.embed`` and cast ``x`` to its dtype when they differ."""
        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
        if x.dtype != self.embed.dtype:
            x = x.to(self.embed.dtype)
        return x

    def down_forward(self, x, *args, **kwargs):
        """Replacement for ``lora_down.forward`` using the generated down weight."""
        x = self._fetch_embed(x)
        down_size = math.prod(self.down_shape)

        conv_kwargs = None
        if len(self.down_shape) == 4:
            # the down conv mirrors the wrapped module's geometry
            org_module = self.lora_module_ref().orig_module_ref()
            conv_kwargs = {"padding": org_module.padding, "stride": org_module.stride}

        return self._apply_per_sample(x, self.embed[:, :down_size], self.down_shape, conv_kwargs)

    def up_forward(self, x, *args, **kwargs):
        """Replacement for ``lora_up.forward`` using the generated up weight."""
        x = self._fetch_embed(x)
        up_size = math.prod(self.up_shape)

        conv_kwargs = None
        if len(self.up_shape) == 4:
            # a kernel of width 3 gets padding 1; any other width is unpadded
            conv_kwargs = {"padding": 1 if self.up_shape[-1] == 3 else 0}

        return self._apply_per_sample(x, self.embed[:, -up_size:], self.up_shape, conv_kwargs)
print(f" ILORA output size: {number_formatted_output_size}") + + # if not evenly divisible, error + if self.output_size % self.num_heads != 0: + raise ValueError("Output size must be divisible by the number of heads") + + self.head_output_size = self.output_size // self.num_heads + + if vision_tokens > 1: + self.resampler = Resampler( + dim=vision_hidden_size, + depth=4, + dim_head=64, + heads=12, + num_queries=num_heads, # output tokens + embedding_dim=vision_hidden_size, + max_seq_len=vision_tokens, + output_dim=head_dim, + apply_pos_emb=True, # this is new + ff_mult=4 + ) + + self.proj_module = LoRAGenerator( + input_size=head_dim, + hidden_size=head_dim, + head_size=head_dim, + num_mlp_layers=1, + num_heads=self.num_heads, + output_size=self.output_size, + ) + + self.migrate_weight_mapping() + + def migrate_weight_mapping(self): + return + # # changes the names of the modules to common ones + # keymap = self.sd_ref().network.get_keymap() + # save_keymap = {} + # if keymap is not None: + # for ldm_key, diffusers_key in keymap.items(): + # # invert them + # save_keymap[diffusers_key] = ldm_key + # + # new_keymap = {} + # for key, value in self.weight_mapping: + # if key in save_keymap: + # new_keymap[save_keymap[key]] = value + # else: + # print(f"Key {key} not found in keymap") + # new_keymap[key] = value + # self.weight_mapping = new_keymap + # else: + # print("No keymap found. 
Using default names") + # return + + + def forward(self, img_embeds): + # expand token rank if only rank 2 + if len(img_embeds.shape) == 2: + img_embeds = img_embeds.unsqueeze(1) + + # resample the image embeddings + img_embeds = self.resampler(img_embeds) + img_embeds = self.proj_module(img_embeds) + if len(img_embeds.shape) == 3: + # merge the heads + img_embeds = img_embeds.mean(dim=1) + + self.img_embeds = [] + # get all the slices + start = 0 + for length in self.embed_lengths: + self.img_embeds.append(img_embeds[:, start:start+length]) + start += length + + + def get_additional_save_metadata(self) -> Dict[str, Any]: + # save the weight mapping + return { + "weight_mapping": self.weight_mapping, + "num_heads": self.num_heads, + "vision_hidden_size": self.vision_hidden_size, + "head_dim": self.head_dim, + "vision_tokens": self.vision_tokens, + "output_size": self.output_size, + } + diff --git a/ai-toolkit/toolkit/models/ilora2.py b/ai-toolkit/toolkit/models/ilora2.py new file mode 100644 index 0000000000000000000000000000000000000000..886d263ca6a002a1593b6b8f1923324539a9dcf4 --- /dev/null +++ b/ai-toolkit/toolkit/models/ilora2.py @@ -0,0 +1,358 @@ +import math +import weakref + +import torch +import torch.nn as nn +from typing import TYPE_CHECKING, List, Dict, Any +from toolkit.resampler import Resampler + +if TYPE_CHECKING: + from toolkit.lora_special import LoRAModule + from toolkit.stable_diffusion_model import StableDiffusion + + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.dropout = nn.Dropout(dropout) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x = 
self.dropout(x) + if self.use_residual: + x = x + residual + return x + +class LoRAGenerator(torch.nn.Module): + def __init__( + self, + input_size: int = 768, # projection dimension + hidden_size: int = 768, + head_size: int = 512, + num_heads: int = 1, + num_mlp_layers: int = 1, + output_size: int = 768, + dropout: float = 0.0 + ): + super().__init__() + self.input_size = input_size + self.num_heads = num_heads + self.simple = False + + self.output_size = output_size + + if self.simple: + self.head = nn.Linear(input_size, head_size, bias=False) + else: + self.lin_in = nn.Linear(input_size, hidden_size) + + self.mlp_blocks = nn.Sequential(*[ + MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers) + ]) + self.head = nn.Linear(hidden_size, head_size, bias=False) + self.norm = nn.LayerNorm(head_size) + + if num_heads == 1: + self.output = nn.Linear(head_size, self.output_size) + # for each output block. multiply weights by 0.01 + with torch.no_grad(): + self.output.weight.data *= 0.01 + else: + head_output_size = output_size // num_heads + self.outputs = nn.ModuleList([nn.Linear(head_size, head_output_size) for _ in range(num_heads)]) + # for each output block. 
multiply weights by 0.01 + with torch.no_grad(): + for output in self.outputs: + output.weight.data *= 0.01 + + # allow get device + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def forward(self, embedding): + if len(embedding.shape) == 2: + embedding = embedding.unsqueeze(1) + + x = embedding + + if not self.simple: + x = self.lin_in(embedding) + x = self.mlp_blocks(x) + x = self.head(x) + x = self.norm(x) + + if self.num_heads == 1: + x = self.output(x) + else: + out_chunks = torch.chunk(x, self.num_heads, dim=1) + x = [] + for out_layer, chunk in zip(self.outputs, out_chunks): + x.append(out_layer(chunk)) + x = torch.cat(x, dim=-1) + + return x.squeeze(1) + + +class InstantLoRAMidModule(torch.nn.Module): + def __init__( + self, + index: int, + lora_module: 'LoRAModule', + instant_lora_module: 'InstantLoRAModule', + up_shape: list = None, + down_shape: list = None, + ): + super(InstantLoRAMidModule, self).__init__() + self.up_shape = up_shape + self.down_shape = down_shape + self.index = index + self.lora_module_ref = weakref.ref(lora_module) + self.instant_lora_module_ref = weakref.ref(instant_lora_module) + + self.embed = None + + def down_forward(self, x, *args, **kwargs): + # get the embed + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + if x.dtype != self.embed.dtype: + x = x.to(self.embed.dtype) + down_size = math.prod(self.down_shape) + down_weight = self.embed[:, :down_size] + + batch_size = x.shape[0] + + # unconditional + if down_weight.shape[0] * 2 == batch_size: + down_weight = torch.cat([down_weight] * 2, dim=0) + + weight_chunks = torch.chunk(down_weight, batch_size, dim=0) + x_chunks = torch.chunk(x, batch_size, dim=0) + + x_out = [] + for i in range(batch_size): + weight_chunk = weight_chunks[i] + x_chunk = x_chunks[i] + # reshape + weight_chunk = weight_chunk.view(self.down_shape) + # check if is conv or linear + if 
class InstantLoRAModule(torch.nn.Module):
    """Generates the weights of every LoRA module in ``sd.network`` from image embeddings.

    On construction each LoRA module's ``lora_down`` / ``lora_up`` forward is
    replaced by an ``InstantLoRAMidModule`` method. ``forward`` then turns image
    embeddings into one long vector per sample and slices it into per-module
    chunks stored on ``self.img_embeds``, which the mid modules read.
    """

    def __init__(
            self,
            vision_hidden_size: int,
            vision_tokens: int,
            head_dim: int,
            num_heads: int,  # number of heads in the resampler
            sd: 'StableDiffusion',
            config=None
    ):
        super(InstantLoRAModule, self).__init__()
        # self.linear = torch.nn.Linear(2, 1)
        self.sd_ref = weakref.ref(sd)
        self.dim = sd.network.lora_dim
        self.vision_hidden_size = vision_hidden_size
        self.vision_tokens = vision_tokens
        self.head_dim = head_dim
        self.num_heads = num_heads

        # stores the projection vector. Grabbed by modules
        self.img_embeds: List[torch.Tensor] = None

        # disable merging in. It is slower on inference
        self.sd_ref().network.can_merge_in = False

        self.ilora_modules = torch.nn.ModuleList()

        lora_modules = self.sd_ref().network.get_all_modules()

        output_size = 0

        # per-module slice length inside the generated vector
        self.embed_lengths = []
        # [lora_name, [down_shape, up_shape]] for every module, saved as metadata
        self.weight_mapping = []

        for idx, lora_module in enumerate(lora_modules):
            module_dict = lora_module.state_dict()
            down_shape = list(module_dict['lora_down.weight'].shape)
            up_shape = list(module_dict['lora_up.weight'].shape)

            self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])

            module_size = math.prod(down_shape) + math.prod(up_shape)
            output_size += module_size
            self.embed_lengths.append(module_size)

            # mid module that supplies the generated weights to this LoRA module
            instant_module = InstantLoRAMidModule(
                idx,
                lora_module,
                self,
                up_shape=up_shape,
                down_shape=down_shape
            )

            self.ilora_modules.append(instant_module)

            # replace the LoRA forwards
            lora_module.lora_down.forward = instant_module.down_forward
            lora_module.lora_up.forward = instant_module.up_forward

        self.output_size = output_size

        number_formatted_output_size = "{:,}".format(output_size)

        print(f" ILORA output size: {number_formatted_output_size}")

        # if not evenly divisible, error
        if self.output_size % self.num_heads != 0:
            raise ValueError("Output size must be divisible by the number of heads")

        self.head_output_size = self.output_size // self.num_heads

        # The resampler is only built for multi-token vision inputs. Keep an
        # explicit None otherwise so forward() can skip it instead of failing
        # with AttributeError on a missing attribute.
        self.resampler = None
        if vision_tokens > 1:
            self.resampler = Resampler(
                dim=vision_hidden_size,
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=num_heads,  # output tokens
                embedding_dim=vision_hidden_size,
                max_seq_len=vision_tokens,
                output_dim=head_dim,
                apply_pos_emb=True,  # this is new
                ff_mult=4
            )

        self.proj_module = LoRAGenerator(
            input_size=head_dim,
            hidden_size=head_dim,
            head_size=head_dim,
            num_mlp_layers=1,
            num_heads=self.num_heads,
            output_size=self.output_size,
        )

        self.migrate_weight_mapping()

    def migrate_weight_mapping(self):
        """Currently a no-op; the key-renaming logic below is disabled."""
        return
        # # changes the names of the modules to common ones
        # keymap = self.sd_ref().network.get_keymap()
        # save_keymap = {}
        # if keymap is not None:
        #     for ldm_key, diffusers_key in keymap.items():
        #         # invert them
        #         save_keymap[diffusers_key] = ldm_key
        #
        #     new_keymap = {}
        #     for key, value in self.weight_mapping:
        #         if key in save_keymap:
        #             new_keymap[save_keymap[key]] = value
        #         else:
        #             print(f"Key {key} not found in keymap")
        #             new_keymap[key] = value
        #     self.weight_mapping = new_keymap
        # else:
        #     print("No keymap found. Using default names")
        #     return

    def forward(self, img_embeds):
        """Generate per-module weight slices from ``img_embeds``.

        The result is stored on ``self.img_embeds`` (a list with one 2-D tensor
        per LoRA module); nothing is returned.
        """
        # expand token rank if only rank 2
        if len(img_embeds.shape) == 2:
            img_embeds = img_embeds.unsqueeze(1)

        # resample the image embeddings (only present when vision_tokens > 1)
        # NOTE(review): without a resampler the input goes straight into
        # proj_module, which was built with input_size=head_dim - presumably
        # this requires vision_hidden_size == head_dim; confirm with callers.
        if self.resampler is not None:
            img_embeds = self.resampler(img_embeds)
        img_embeds = self.proj_module(img_embeds)
        if len(img_embeds.shape) == 3:
            # merge the heads
            img_embeds = img_embeds.mean(dim=1)

        self.img_embeds = []
        # get all the slices
        start = 0
        for length in self.embed_lengths:
            self.img_embeds.append(img_embeds[:, start:start + length])
            start += length

    def get_additional_save_metadata(self) -> Dict[str, Any]:
        """Return the configuration and weight mapping to store alongside saved weights."""
        # save the weight mapping
        return {
            "weight_mapping": self.weight_mapping,
            "num_heads": self.num_heads,
            "vision_hidden_size": self.vision_hidden_size,
            "head_dim": self.head_dim,
            "vision_tokens": self.vision_tokens,
            "output_size": self.output_size,
        }
def new_context_embedder_forward(self, x):
    """Forward override installed on a context embedder.

    Delegates to the module behind ``self._context_embedder_ref`` while the
    adapter behind ``self._adapter_ref`` reports ``is_active``; otherwise falls
    back to the saved original forward, ``self._orig_forward``.
    """
    if not self._adapter_ref().is_active:
        return self._orig_forward(x)
    replacement = self._context_embedder_ref()
    return replacement(x)
= num_cloned_blocks + self.apply_embedding_mask = False + # make sure we can pad + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.system_prompt = "" + # self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. " + + # determine length of system prompt + sys_prompt_tokenized = tokenizer( + [self.system_prompt], + padding="longest", + return_tensors="pt", + ) + + sys_prompt_tokenized_ids = sys_prompt_tokenized.input_ids[0] + + self.system_prompt_length = sys_prompt_tokenized_ids.shape[0] + + print(f"System prompt length: {self.system_prompt_length}") + + self.hidden_size = llm.config.hidden_size + + blocks = [] + + if sd.is_flux: + self.apply_embedding_mask = True + self.context_embedder = nn.Linear( + self.hidden_size, sd.unet.inner_dim) + self.sequence_length = 512 + sd.unet.context_embedder._orig_forward = sd.unet.context_embedder.forward + sd.unet.context_embedder.forward = partial( + new_context_embedder_forward, sd.unet.context_embedder) + sd.unet.context_embedder._context_embedder_ref = weakref.ref(self.context_embedder) + # add a is active property to the context embedder + sd.unet.context_embedder._adapter_ref = self.adapter_ref + + for idx in range(self.num_cloned_blocks): + block = FluxTransformerBlock( + dim=sd.unet.inner_dim, + num_attention_heads=24, + attention_head_dim=128, + ) + # patch it in case it is quantized + patch_dequantization_on_save(sd.unet.transformer_blocks[idx]) + state_dict = sd.unet.transformer_blocks[idx].state_dict() + for key, value in state_dict.items(): + block.state_dict()[key].copy_(value) + blocks.append(block) + orig_block = sd.unet.transformer_blocks[idx] + orig_block._orig_forward = orig_block.forward + orig_block.forward = partial( + new_block_forward, orig_block) + orig_block._new_block_ref = weakref.ref(block) + orig_block._adapter_ref = self.adapter_ref + + elif 
sd.is_lumina2: + self.context_embedder = nn.Linear( + self.hidden_size, sd.unet.hidden_size) + self.sequence_length = 256 + else: + raise ValueError( + "llm adapter currently only supports flux or lumina2") + + self.blocks = nn.ModuleList(blocks) + + def _get_prompt_embeds( + self, + prompt: Union[str, List[str]], + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tokenizer = self.tokenizer_ref() + text_encoder = self.llm_ref() + device = text_encoder.device + prompt = [prompt] if isinstance(prompt, str) else prompt + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length + self.system_prompt_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_attention_mask = text_inputs.attention_mask.to(device) + + # remove the system prompt from the input and attention mask + + prompt_embeds = text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + prompt_embeds = prompt_embeds.hidden_states[-1] + + prompt_embeds = prompt_embeds[:, self.system_prompt_length:] + prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:] + + dtype = text_encoder.dtype + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + # make a getter to see if is active + + @property + def is_active(self): + return self.adapter_ref().is_active + + def encode_text(self, prompt): + + prompt = prompt if isinstance(prompt, list) else [prompt] + + prompt = [self.system_prompt + p for p in prompt] + # prompt = [self.system_prompt + p for p in prompt] + + prompt_embeds, prompt_attention_mask = self._get_prompt_embeds( + prompt=prompt, + max_sequence_length=self.sequence_length, + ) + + prompt_embeds = PromptEmbeds( + prompt_embeds, + attention_mask=prompt_attention_mask, + ).detach() + + return prompt_embeds + + def forward(self, input): + return input diff --git 
a/ai-toolkit/toolkit/models/loaders/__init__.py b/ai-toolkit/toolkit/models/loaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ai-toolkit/toolkit/models/loaders/umt5.py b/ai-toolkit/toolkit/models/loaders/umt5.py new file mode 100644 index 0000000000000000000000000000000000000000..f53db3cb9003367783669c2847835dad8fe1fd94 --- /dev/null +++ b/ai-toolkit/toolkit/models/loaders/umt5.py @@ -0,0 +1,46 @@ +from typing import List +import torch +from transformers import T5Tokenizer, UMT5EncoderModel + +class PatchedT5Tokenizer(T5Tokenizer): + def __init__( + self, + vocab: str | list[tuple[str, float]] | None = None, + eos_token="", + unk_token="", + pad_token="", + _spm_precompiled_charsmap=None, + extra_ids=100, + additional_special_tokens=None, + **kwargs, + ): + super().__init__( + vocab=vocab, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + _spm_precompiled_charsmap=None, # this is passing a empty byte string for some reason now. + extra_ids=extra_ids, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + +def get_umt5_encoder( + model_path: str, + tokenizer_subfolder: str = None, + encoder_subfolder: str = None, + torch_dtype: str = torch.bfloat16, + comfy_files: List[str] = [ + "text_encoders/umt5_xxl_fp16.safetensors", + "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", + ], +) -> UMT5EncoderModel: + """ + Load the UMT5 encoder model from the specified path. 
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    Split ``dimension`` into two integer factors ``(m, n)`` with ``m * n == dimension``.

    If ``factor`` is positive and divides ``dimension`` exactly, the result is
    ``(factor, dimension // factor)`` returned as-is (not reordered).

    Otherwise the divisors of ``dimension`` are walked upwards from 1, keeping
    the most balanced pair found, and the walk stops as soon as the next divisor
    would exceed ``factor`` (``factor == -1`` means "no limit"). The pair from
    this search is returned with the smaller value first.

    In LoRA with Kronecker product the two values size the two matrices; because
    A⊗B ≠ B⊗A their roles are not interchangeable.

    examples)
        factorization(127)    -> (1, 127)
        factorization(128)    -> (8, 16)
        factorization(250)    -> (10, 25)
        factorization(1024)   -> (32, 32)
        factorization(128, 2) -> (2, 64)
        factorization(250, 4) -> (2, 125)
    '''
    # exact split requested and possible: take it verbatim
    if factor > 0 and dimension % factor == 0:
        return factor, dimension // factor

    limit = dimension if factor == -1 else factor
    small, large = 1, dimension
    budget = small + large
    while small < large:
        # next divisor of dimension above the current one
        candidate = small + 1
        while dimension % candidate:
            candidate += 1
        cofactor = dimension // candidate
        if candidate + cofactor > budget or candidate > limit:
            break
        small, large = candidate, cofactor

    # the last accepted step may have crossed over; report smaller value first
    return (large, small) if small > large else (small, large)
""" + ToolkitModuleMixin.__init__(self, network=network) + torch.nn.Module.__init__(self) + factor = int(factor) + self.lora_name = lora_name + self.lora_dim = lora_dim + self.cp = False + self.use_w1 = False + self.use_w2 = False + self.can_merge_in = True + + self.shape = org_module.weight.shape + if org_module.__class__.__name__ == 'Conv2d': + in_dim = org_module.in_channels + k_size = org_module.kernel_size + out_dim = org_module.out_channels + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + # ((a, b), (c, d), *k_size) + shape = ((out_l, out_k), (in_m, in_n), *k_size) + + self.cp = use_cp and k_size != (1, 1) + if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2: + self.lokr_w1_a = nn.Parameter( + torch.empty(shape[0][0], lora_dim)) + self.lokr_w1_b = nn.Parameter( + torch.empty(lora_dim, shape[1][0])) + else: + self.use_w1 = True + self.lokr_w1 = nn.Parameter(torch.empty( + shape[0][0], shape[1][0])) # a*c, 1-mode + + if lora_dim >= max(shape[0][1], shape[1][1])/2: + self.use_w2 = True + self.lokr_w2 = nn.Parameter(torch.empty( + shape[0][1], shape[1][1], *k_size)) + elif self.cp: + self.lokr_t2 = nn.Parameter(torch.empty( + lora_dim, lora_dim, shape[2], shape[3])) + self.lokr_w2_a = nn.Parameter( + torch.empty(lora_dim, shape[0][1])) # b, 1-mode + self.lokr_w2_b = nn.Parameter( + torch.empty(lora_dim, shape[1][1])) # d, 2-mode + else: # Conv2d not cp + # bigger part. weight and LoRA. 
[b, dim] x [dim, d*k1*k2] + self.lokr_w2_a = nn.Parameter( + torch.empty(shape[0][1], lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty( + lora_dim, shape[1][1]*shape[2]*shape[3])) + # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2) + + self.op = F.conv2d + self.extra_args = { + "stride": org_module.stride, + "padding": org_module.padding, + "dilation": org_module.dilation, + "groups": org_module.groups + } + + else: # Linear + in_dim = org_module.in_features + out_dim = org_module.out_features + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d + shape = ((out_l, out_k), (in_m, in_n)) + + # smaller part. weight scale + if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2: + self.lokr_w1_a = nn.Parameter( + torch.empty(shape[0][0], lora_dim)) + self.lokr_w1_b = nn.Parameter( + torch.empty(lora_dim, shape[1][0])) + else: + self.use_w1 = True + self.lokr_w1 = nn.Parameter(torch.empty( + shape[0][0], shape[1][0])) # a*c, 1-mode + + if lora_dim < max(shape[0][1], shape[1][1])/2: + # bigger part. weight and LoRA. 
[b, dim] x [dim, d] + self.lokr_w2_a = nn.Parameter( + torch.empty(shape[0][1], lora_dim)) + self.lokr_w2_b = nn.Parameter( + torch.empty(lora_dim, shape[1][1])) + # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd) + else: + self.use_w2 = True + self.lokr_w2 = nn.Parameter( + torch.empty(shape[0][1], shape[1][1])) + + self.op = F.linear + self.extra_args = {} + + self.dropout = dropout + if dropout: + print("[WARN]LoKr haven't implemented normal dropout yet.") + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + if isinstance(alpha, torch.Tensor): + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + if self.use_w2 and self.use_w1: + # use scale = 1 + alpha = lora_dim + self.scale = alpha / self.lora_dim + self.register_buffer('alpha', torch.tensor(alpha)) # treat as constant + + if self.use_w2: + torch.nn.init.constant_(self.lokr_w2, 0) + else: + if self.cp: + torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5)) + torch.nn.init.constant_(self.lokr_w2_b, 0) + + if self.use_w1: + torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5)) + else: + torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5)) + + self.multiplier = multiplier + self.org_module = [org_module] + weight = make_kron( + self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b, + (self.lokr_w2 if self.use_w2 + else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp + else self.lokr_w2_a@self.lokr_w2_b), + torch.tensor(self.multiplier * self.scale) + ) + assert torch.sum(torch.isnan(weight)) == 0, "weight is nan" + + # Same as locon.py + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, orig_weight=None): 
+ weight = make_kron( + self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b, + (self.lokr_w2 if self.use_w2 + else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp + else self.lokr_w2_a@self.lokr_w2_b), + torch.tensor(self.scale) + ) + if orig_weight is not None: + weight = weight.reshape(orig_weight.shape) + if self.training and self.rank_dropout: + drop = torch.rand(weight.size(0)) < self.rank_dropout + weight *= drop.view(-1, [1] * + len(weight.shape[1:])).to(weight.device) + return weight + + @torch.no_grad() + def merge_in(self, merge_weight=1.0): + if not self.can_merge_in: + return + + # extract weight from org_module + org_sd = self.org_module[0].state_dict() + # todo find a way to merge in weights when doing quanto quantized model + if 'weight._data' in org_sd: + # quanto quantized weight + return + + weight_key = "weight" + from toolkit.util.quantize import is_quantized_tensor + org_weight = self.org_module[0].weight + is_ao_quantized = is_quantized_tensor(org_weight) + orig_dtype = org_weight.dtype + # dequantize torchao weights so the delta can be merged in full precision + weight = (org_weight.dequantize() if is_ao_quantized else org_weight).float() + + scale = self.scale + # handle trainable scaler method locon does + if hasattr(self, 'scalar'): + scale = scale * self.scalar + + lokr_weight = self.get_weight(weight) + + merged_weight = ( + weight + + (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype) + ) + + # write the merged weight back, re-quantizing if the original was torchao quantized so the + # model stays quantized across continuous merge/reset cycles + if is_ao_quantized: + from toolkit.util.quantize import get_torchao_config, requantize_module_weight + requantize_module_weight( + self.org_module[0], merged_weight, orig_dtype, get_torchao_config(self._get_base_qtype()) + ) + else: + org_sd[weight_key] = merged_weight.to(orig_dtype) + self.org_module[0].load_state_dict(org_sd) + + def 
get_orig_weight(self, device): + weight = self.org_module[0].weight + if weight.device != device: + weight = weight.to(device) + if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor): + return weight.dequantize().data.detach() + elif isinstance(weight, AffineQuantizedTensor): + return weight.dequantize().data.detach() + else: + return weight.data.detach() + + def get_orig_bias(self, device): + if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None: + bias = self.org_module[0].bias + if bias.device != device: + bias = bias.to(device) + if isinstance(bias, QTensor) or isinstance(bias, QBytesTensor): + return bias.dequantize().data.detach() + elif isinstance(bias, AffineQuantizedTensor): + return bias.dequantize().data.detach() + else: + return self.org_module[0].bias.data.detach() + return None + + def _call_forward(self, x): + if isinstance(x, QTensor) or isinstance(x, QBytesTensor): + x = x.dequantize() + + orig_dtype = x.dtype + + orig_weight = self.get_orig_weight(x.device) + lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype) + multiplier = self.network_ref().torch_multiplier + + if x.dtype != orig_weight.dtype: + x = x.to(dtype=orig_weight.dtype) + + # we do not currently support split batch multipliers for lokr. 
Just do a mean + multiplier = torch.mean(multiplier) + + weight = ( + orig_weight + + lokr_weight * multiplier + ) + bias = self.get_orig_bias(x.device) + if bias is not None: + bias = bias.to(weight.device, dtype=weight.dtype) + output = self.op( + x, + weight.view(self.shape), + bias, + **self.extra_args + ) + return output.to(orig_dtype) diff --git a/ai-toolkit/toolkit/models/mean_flow_adapter.py b/ai-toolkit/toolkit/models/mean_flow_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb182c42118c4924af3714228a7ac3de27b1618 --- /dev/null +++ b/ai-toolkit/toolkit/models/mean_flow_adapter.py @@ -0,0 +1,324 @@ +import inspect +import weakref +import torch +from typing import TYPE_CHECKING, Tuple +from toolkit.lora_special import LoRASpecialNetwork +from diffusers import FluxTransformer2DModel +from diffusers.models.embeddings import ( + CombinedTimestepTextProjEmbeddings, + CombinedTimestepGuidanceTextProjEmbeddings, +) +from functools import partial + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig + from toolkit.custom_adapter import CustomAdapter + from extensions_built_in.diffusion_models.omnigen2.src.models.transformers import OmniGen2Transformer2DModel + + +def mean_flow_time_text_embed_forward( + self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection +): + mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref() + # make zero timestep ending if none is passed + if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]: + timestep = torch.cat( + [timestep, torch.zeros_like(timestep)], dim=0 + ) # timestep - 0 (final timestep) == same as start timestep + + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=pooled_projection.dtype) + ) # (N, D) + + # mean flow stuff + if mean_flow_adapter.is_active: + # todo make 
sure that timesteps is batched correctly, I think diffusers expects non batched timesteps + orig_dtype = timesteps_emb.dtype + timesteps_emb = timesteps_emb.to(torch.float32) + timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0) + timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder( + torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1) + ) + timesteps_emb = timesteps_emb.to(orig_dtype) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +def mean_flow_time_text_guidance_embed_forward( + self: CombinedTimestepGuidanceTextProjEmbeddings, + timestep, + guidance, + pooled_projection, +): + mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref() + # make zero timestep ending if none is passed + if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]: + timestep = torch.cat( + [timestep, torch.ones_like(timestep)], dim=0 + ) # timestep - 0 (final timestep) == same as start timestep + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=pooled_projection.dtype) + ) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder( + guidance_proj.to(dtype=pooled_projection.dtype) + ) # (N, D) + + # mean flow stuff + if mean_flow_adapter.is_active: + # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps + orig_dtype = timesteps_emb.dtype + timesteps_emb = timesteps_emb.to(torch.float32) + timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0) + timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder( + torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1) + ) + timesteps_emb = timesteps_emb.to(orig_dtype) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = 
time_guidance_emb + pooled_projections + + return conditioning + + +def convert_flux_to_mean_flow( + transformer: FluxTransformer2DModel, +): + if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings): + transformer.time_text_embed.forward = partial( + mean_flow_time_text_embed_forward, transformer.time_text_embed + ) + elif isinstance( + transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings + ): + transformer.time_text_embed.forward = partial( + mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed + ) + else: + raise ValueError( + "Unsupported time_text_embed type: {}".format( + type(transformer.time_text_embed) + ) + ) + +def mean_flow_omnigen2_time_text_embed_forward( + self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype +) -> Tuple[torch.Tensor, torch.Tensor]: + mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref() + if mean_flow_adapter.is_active and timestep.shape[0] == text_hidden_states.shape[0]: + timestep = torch.cat( + [timestep, torch.ones_like(timestep)], dim=0 # omnigen does reverse timesteps + ) + timestep_proj = self.time_proj(timestep).to(dtype=dtype) + time_embed = self.timestep_embedder(timestep_proj) + + # mean flow stuff + if mean_flow_adapter.is_active: + # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps + orig_dtype = time_embed.dtype + time_embed = time_embed.to(torch.float32) + time_embed_start, time_embed_end = time_embed.chunk(2, dim=0) + time_embed = mean_flow_adapter.mean_flow_timestep_embedder( + torch.cat([time_embed_start, time_embed_end], dim=-1) + ) + time_embed = time_embed.to(orig_dtype) + + caption_embed = self.caption_embedder(text_hidden_states) + return time_embed, caption_embed + + +def convert_omnigen2_to_mean_flow( + transformer: 'OmniGen2Transformer2DModel', +): + transformer.time_caption_embed.forward = partial( + mean_flow_omnigen2_time_text_embed_forward, 
transformer.time_caption_embed + ) + +class MeanFlowAdapter(torch.nn.Module): + def __init__( + self, + adapter: "CustomAdapter", + sd: "StableDiffusion", + config: "AdapterConfig", + train_config: "TrainConfig", + ): + super().__init__() + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref = weakref.ref(sd) + self.model_config: ModelConfig = sd.model_config + self.network_config = config.lora_config + self.train_config = train_config + self.device_torch = sd.device_torch + self.lora = None + + if self.network_config is not None: + network_kwargs = ( + {} + if self.network_config.network_kwargs is None + else self.network_config.network_kwargs + ) + if hasattr(sd, "target_lora_modules"): + network_kwargs["target_lin_modules"] = sd.target_lora_modules + + if "ignore_if_contains" not in network_kwargs: + network_kwargs["ignore_if_contains"] = [] + + self.lora = LoRASpecialNetwork( + text_encoder=sd.text_encoder, + unet=sd.unet, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=False, + is_lorm=False, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=sd.is_transformer, + base_model=sd, + 
**network_kwargs, + ) + self.lora.force_to(self.device_torch, dtype=torch.float32) + self.lora._update_torch_multiplier() + self.lora.apply_to( + sd.text_encoder, + sd.unet, + self.train_config.train_text_encoder, + self.train_config.train_unet, + ) + self.lora.can_merge_in = False + self.lora.prepare_grad_etc(sd.text_encoder, sd.unet) + if self.train_config.gradient_checkpointing: + self.lora.enable_gradient_checkpointing() + + emb_dim = None + if self.model_config.arch in ["flux", "flex2", "flex2"]: + transformer: FluxTransformer2DModel = sd.unet + emb_dim = ( + transformer.config.num_attention_heads + * transformer.config.attention_head_dim + ) + convert_flux_to_mean_flow(transformer) + + elif self.model_config.arch in ["omnigen2"]: + transformer: 'OmniGen2Transformer2DModel' = sd.unet + emb_dim = ( + 1024 + ) + convert_omnigen2_to_mean_flow(transformer) + else: + raise ValueError(f"Unsupported architecture: {self.model_config.arch}") + + self.mean_flow_timestep_embedder = torch.nn.Linear( + emb_dim * 2, + emb_dim, + ) + + # make the model function as before adding this adapter by initializing the weights + with torch.no_grad(): + self.mean_flow_timestep_embedder.weight.zero_() + self.mean_flow_timestep_embedder.weight[:, :emb_dim] = torch.eye(emb_dim) + self.mean_flow_timestep_embedder.bias.zero_() + + self.mean_flow_timestep_embedder.to(self.device_torch) + + # add our adapter as a weak ref + if self.model_config.arch in ["flux", "flex2", "flex2"]: + sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self) + elif self.model_config.arch in ["omnigen2"]: + sd.unet.time_caption_embed.mean_flow_adapter_ref = weakref.ref(self) + + def get_params(self): + if self.lora is not None: + config = { + "text_encoder_lr": self.train_config.lr, + "unet_lr": self.train_config.lr, + } + sig = inspect.signature(self.lora.prepare_optimizer_params) + if "default_lr" in sig.parameters: + config["default_lr"] = self.train_config.lr + if "learning_rate" in sig.parameters: + 
config["learning_rate"] = self.train_config.lr + params_net = self.lora.prepare_optimizer_params(**config) + + # we want only tensors here + params = [] + for p in params_net: + if isinstance(p, dict): + params += p["params"] + elif isinstance(p, torch.Tensor): + params.append(p) + elif isinstance(p, list): + params += p + else: + params = [] + + # make sure the embedder is float32 + self.mean_flow_timestep_embedder.to(torch.float32) + self.mean_flow_timestep_embedder.requires_grad = True + self.mean_flow_timestep_embedder.train() + + params += list(self.mean_flow_timestep_embedder.parameters()) + + # we need to be able to yield from the list like yield from params + + return params + + def load_weights(self, state_dict, strict=True): + lora_sd = {} + mean_flow_embedder_sd = {} + for key, value in state_dict.items(): + if "mean_flow_timestep_embedder" in key: + new_key = key.replace("transformer.mean_flow_timestep_embedder.", "") + mean_flow_embedder_sd[new_key] = value + else: + lora_sd[key] = value + + # todo process state dict before loading for models that need it + if self.lora is not None: + self.lora.load_weights(lora_sd) + self.mean_flow_timestep_embedder.load_state_dict( + mean_flow_embedder_sd, strict=False + ) + + def get_state_dict(self): + if self.lora is not None: + lora_sd = self.lora.get_state_dict(dtype=torch.float32) + else: + lora_sd = {} + # todo make sure we match loras elseware. 
+ mean_flow_embedder_sd = self.mean_flow_timestep_embedder.state_dict() + for key, value in mean_flow_embedder_sd.items(): + lora_sd[f"transformer.mean_flow_timestep_embedder.{key}"] = value + return lora_sd + + @property + def is_active(self): + return self.adapter_ref().is_active diff --git a/ai-toolkit/toolkit/models/pixtral_vision.py b/ai-toolkit/toolkit/models/pixtral_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..815f33101ffb4cfc672cefde6f94fd01ae198783 --- /dev/null +++ b/ai-toolkit/toolkit/models/pixtral_vision.py @@ -0,0 +1,618 @@ +import math +from typing import List, Optional, Tuple, Any, Union, TYPE_CHECKING +import os +import torch +import torch.nn as nn +from dataclasses import dataclass +from huggingface_hub import snapshot_download +from safetensors.torch import load_file +import json + +if TYPE_CHECKING: + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int, **kwargs): + super().__init__() + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # type: ignore + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]: + keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim) + values = torch.repeat_interleave(values, 
repeats=repeats, dim=dim) + return keys, values + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = freqs_cis[:, None, :] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + n_heads: int, + head_dim: int, + n_kv_heads: int, + **kwargs, + ): + super().__init__() + + self.n_heads: int = n_heads + self.head_dim: int = head_dim + self.n_kv_heads: int = n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.head_dim ** -0.5 + + self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + cache: Optional[Any] = None, + mask: Optional['BlockDiagonalMask'] = None, + ) -> torch.Tensor: + from xformers.ops.fmha import memory_efficient_attention + assert mask is None or cache is None + seqlen_sum, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(seqlen_sum, self.n_heads, self.head_dim) + xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim) + xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + if cache is None: + key, val = xk, xv + elif cache.prefill: + key, val = cache.interleave_kv(xk, xv) + cache.update(xk, xv) + else: + cache.update(xk, xv) + key, val = cache.key, cache.value + key = key.view(seqlen_sum * cache.max_seq_len, + self.n_kv_heads, self.head_dim) + val = 
val.view(seqlen_sum * cache.max_seq_len, + self.n_kv_heads, self.head_dim) + + # Repeat keys and values to match number of query heads + key, val = repeat_kv(key, val, self.repeats, dim=1) + + # xformers requires (B=1, S, H, D) + xq, key, val = xq[None, ...], key[None, ...], val[None, ...] + output = memory_efficient_attention( + xq, key, val, mask if cache is None else cache.mask) + output = output.view(seqlen_sum, self.n_heads * self.head_dim) + + assert isinstance(output, torch.Tensor) + + return self.wo(output) # type: ignore + + +class TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + norm_eps: float, + **kwargs, + ): + super().__init__() + self.n_heads = n_heads + self.dim = dim + self.attention = Attention( + dim=dim, + n_heads=n_heads, + head_dim=head_dim, + n_kv_heads=n_kv_heads, + ) + self.attention_norm = RMSNorm(dim, eps=norm_eps) + self.ffn_norm = RMSNorm(dim, eps=norm_eps) + + self.feed_forward: nn.Module + self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + cache: Optional[Any] = None, + mask: Optional['BlockDiagonalMask'] = None, + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), freqs_cis, cache) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +@dataclass +class VisionEncoderArgs: + hidden_size: int + num_channels: int + image_size: int + patch_size: int + intermediate_size: int + num_hidden_layers: int + num_attention_heads: int + rope_theta: float = 1e4 # for rope-2D + image_token_id: int = 10 + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by + (height, width) position tuples + """ + # (dim / 2) frequency bases + freqs = 1.0 / (theta ** (torch.arange(0, dim, 
2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def position_meshgrid( + patch_embeds_list: list[torch.Tensor], +) -> torch.Tensor: + positions = torch.cat( + [ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + for p in patch_embeds_list + ] + ) + return positions + + +class PixtralVisionEncoder(nn.Module): + def __init__( + self, + hidden_size: int = 1024, + num_channels: int = 3, + image_size: int = 1024, + patch_size: int = 16, + intermediate_size: int = 4096, + num_hidden_layers: int = 24, + num_attention_heads: int = 16, + rope_theta: float = 1e4, # for rope-2D + image_token_id: int = 10, + **kwargs, + ): + super().__init__() + self.args = VisionEncoderArgs( + hidden_size=hidden_size, + num_channels=num_channels, + image_size=image_size, + patch_size=patch_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + rope_theta=rope_theta, + image_token_id=image_token_id, + ) + args = self.args + self.patch_conv = nn.Conv2d( + in_channels=args.num_channels, + out_channels=args.hidden_size, + kernel_size=args.patch_size, + stride=args.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) + self.transformer = VisionTransformerBlocks(args) + + head_dim = self.args.hidden_size // self.args.num_attention_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + self._freqs_cis: Optional[torch.Tensor] = None + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder': + if 
os.path.isdir(pretrained_model_name_or_path): + model_folder = pretrained_model_name_or_path + else: + model_folder = snapshot_download(pretrained_model_name_or_path) + + # make sure there is a config + if not os.path.exists(os.path.join(model_folder, "config.json")): + raise ValueError(f"Could not find config.json in {model_folder}") + + # load config + with open(os.path.join(model_folder, "config.json"), "r") as f: + config = json.load(f) + + model = cls(**config) + + # see if there is a state_dict + if os.path.exists(os.path.join(model_folder, "model.safetensors")): + state_dict = load_file(os.path.join( + model_folder, "model.safetensors")) + model.load_state_dict(state_dict) + + return model + + @property + def max_patches_per_side(self) -> int: + return self.args.image_size // self.args.patch_size + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def freqs_cis(self) -> torch.Tensor: + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.args.hidden_size // self.args.num_attention_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.args.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + images: List[torch.Tensor], + ) -> torch.Tensor: + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + """ + Args: + images: list of N_img images of variable sizes, each of shape (C, H, W) + + Returns: + image_features: tensor of token features for all tokens of all images of + shape (N_toks, D) + """ + assert isinstance( + images, list), f"Expected list of images, got {type(images)}" + assert all(len(img.shape) == 3 for img in + images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}" + # pass images through initial convolution independently + patch_embeds_list = [self.patch_conv( + 
img.unsqueeze(0)).squeeze(0) for img in images] + + # flatten to a single sequence + patch_embeds = torch.cat([p.flatten(1).permute(1, 0) + for p in patch_embeds_list], dim=0) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + positions = position_meshgrid(patch_embeds_list).to(self.device) + freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + + # pass through Transformer with a block diagonal mask delimiting images + mask = BlockDiagonalMask.from_seqlens( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + ) + out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) + + # remove batch dimension of the single sequence + return out # type: ignore[no-any-return] + + +class VisionLanguageAdapter(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.w_in = nn.Linear( + in_dim, + out_dim, + bias=True, + ) + self.gelu = nn.GELU() + self.w_out = nn.Linear(out_dim, out_dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # type: ignore[no-any-return] + return self.w_out(self.gelu(self.w_in(x))) + + +class VisionTransformerBlocks(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.layers = torch.nn.ModuleList() + for _ in range(args.num_hidden_layers): + self.layers.append( + TransformerBlock( + dim=args.hidden_size, + hidden_dim=args.intermediate_size, + n_heads=args.num_attention_heads, + n_kv_heads=args.num_attention_heads, + head_dim=args.hidden_size // args.num_attention_heads, + norm_eps=1e-5, + ) + ) + + def forward( + self, + x: torch.Tensor, + mask: 'BlockDiagonalMask', + freqs_cis: Optional[torch.Tensor], + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, mask=mask, freqs_cis=freqs_cis) + return x + + +DATASET_MEAN = [0.48145466, 0.4578275, 0.40821073] # RGB +DATASET_STD = [0.26862954, 0.26130258, 0.27577711] # RGB + + +def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: 
class PixtralVisionImagePreprocessor:
    """Resize and normalize a single ``(C, H, W)`` image for the Pixtral encoder.

    The output resolution is always a whole number of ``image_patch_size``
    patches per side. Images whose "base size" ``int(sqrt(W * H))`` exceeds
    ``max_image_size`` are scaled down before the patch grid is computed.
    """

    def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
        self.image_patch_size = image_patch_size
        self.max_image_size = max_image_size
        self.image_token = 10

    def _image_to_num_tokens(self, img: torch.Tensor, max_image_size = None) -> Tuple[int, int]:
        """Return ``(width_tokens, height_tokens)`` for ``img``.

        Args:
            img: tensor whose last two dimensions are ``(H, W)``.
            max_image_size: optional override of ``self.max_image_size``.
        """
        w: Union[int, float]
        h: Union[int, float]

        if max_image_size is None:
            max_image_size = self.max_image_size

        w, h = img.shape[-1], img.shape[-2]

        # originally, pixtral used the largest of the 2 dimensions, but we
        # will use the base size of the image based on number of pixels.
        # ratio = max(h / self.max_image_size, w / self.max_image_size)  # original

        base_size = int(math.sqrt(w * h))
        ratio = base_size / max_image_size
        if ratio > 1:
            w = round(w / ratio)
            h = round(h / ratio)

        # Ceil-divide each side by the patch size.
        width_tokens = (w - 1) // self.image_patch_size + 1
        height_tokens = (h - 1) // self.image_patch_size + 1

        return width_tokens, height_tokens

    def _target_size(self, img: torch.Tensor, max_image_size=None) -> Tuple[int, int]:
        """Return the resize target in pixels as ``(height, width)``.

        ``_image_to_num_tokens`` reports width first, while
        ``torch.nn.functional.interpolate`` (used by ``transform_image``)
        expects ``size=(height, width)``; this helper performs the reordering
        in exactly one place.
        """
        width_tokens, height_tokens = self._image_to_num_tokens(
            img, max_image_size=max_image_size
        )
        assert width_tokens > 0
        assert height_tokens > 0
        return (
            height_tokens * self.image_patch_size,
            width_tokens * self.image_patch_size,
        )

    def __call__(self, image: torch.Tensor, max_image_size=None) -> torch.Tensor:
        """Resize ``image`` to a patch-aligned resolution and normalize it.

        Args:
            image: tensor of shape ``(C, H, W)`` with values in ``[0, 1]``.
            max_image_size: optional override of ``self.max_image_size``.

        Returns:
            The output of ``transform_image`` for the ``(height, width)`` target
            size computed by ``_target_size``.

        Raises:
            ValueError: if ``image`` is 4-D (batched) or has values outside
                ``[0, 1]``.
        """
        # should not have batch
        if len(image.shape) == 4:
            raise ValueError(
                f"Expected image with shape (C, H, W), got {image.shape}")

        if image.min() < 0.0 or image.max() > 1.0:
            raise ValueError(
                f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")

        if max_image_size is None:
            max_image_size = self.max_image_size

        # Previously the size was passed as (width, height), which swapped the
        # axes of every non-square image during the resize.
        new_image_size = self._target_size(image, max_image_size=max_image_size)

        processed_image = transform_image(image, new_image_size)

        return processed_image
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
    """``PixtralVisionEncoder`` wrapper that accepts a batched image tensor.

    Each image is encoded on its own through the parent ``forward`` and the
    per-image results are stacked and wrapped in
    ``PixtralVisionEncoderCompatibleReturn``.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_channels: int = 3,
        image_size: int = 1024,
        patch_size: int = 16,
        intermediate_size: int = 4096,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        rope_theta: float = 1e4,  # for rope-2D
        image_token_id: int = 10,
        **kwargs,
    ):
        # Extra keyword arguments are accepted but not forwarded to the parent.
        encoder_kwargs = dict(
            hidden_size=hidden_size,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            rope_theta=rope_theta,
            image_token_id=image_token_id,
        )
        super().__init__(**encoder_kwargs)
        self.config = PixtralVisionEncoderCompatibleConfig()

    def forward(
        self,
        images,
        output_hidden_states=True,
    ) -> torch.Tensor:
        """Encode ``images`` one at a time and stack the results.

        Args:
            images: a 3-D tensor (treated as a batch of one) or a tensor whose
                first dimension indexes images.
            output_hidden_states: accepted for call compatibility; not read.

        Returns:
            A ``PixtralVisionEncoderCompatibleReturn`` whose ``hidden_states``
            is a one-element list holding the stacked outputs (the annotation
            says ``torch.Tensor``, but the wrapper object is what is returned).
        """
        # Bind the parent forward once; zero-argument super() is not usable
        # inside a comprehension scope.
        encode_one = super().forward

        batch = images.unsqueeze(0) if len(images.shape) == 3 else images

        # The parent forward expects a list of images, hence the [..] wrapper.
        per_image = [encode_one([batch[idx]]) for idx in range(batch.shape[0])]

        stacked = torch.stack(per_image, dim=0)
        return PixtralVisionEncoderCompatibleReturn([stacked])
class ReduxImageEncoder(torch.nn.Module):
    """Two-layer projection: ``redux_dim -> 3 * txt_in_features -> txt_in_features``.

    A SiLU activation sits between the two linear layers.

    Args:
        redux_dim: feature size of the incoming embeddings.
        txt_in_features: feature size of the projected output.
        device: device on which the linear layers are created. It was
            previously only stored on the module and the layers were always
            created on the default device.
        dtype: dtype of the linear layers.
    """

    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.redux_dim = redux_dim
        self.device = device
        self.dtype = dtype
        # Forward both factory arguments so the parameters are created where
        # (and as what) the caller asked for; None keeps the torch defaults.
        factory_kwargs = {"device": device, "dtype": dtype}
        self.redux_up = nn.Linear(
            redux_dim, txt_in_features * 3, **factory_kwargs)
        self.redux_down = nn.Linear(
            txt_in_features * 3, txt_in_features, **factory_kwargs)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        """Project ``sigclip_embeds`` (last dim ``redux_dim``) to ``txt_in_features``."""
        x = self.redux_up(sigclip_embeds)
        x = torch.nn.functional.silu(x)

        projected_x = self.redux_down(x)
        return projected_x
+# https://raw.githubusercontent.com/facebookresearch/sapiens2/refs/heads/main/sapiens/backbones/standalone/sapiens2.py + +import math +from typing import Any, List, Literal, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.init import trunc_normal_ +from torch.utils.checkpoint import checkpoint +from toolkit.paths import MODELS_PATH +import os + + +# ---------------------------------------------------------------------------- +def to_2tuple(x): + if isinstance(x, (str, bytes)): + return (x, x) + if isinstance(x, Sequence): + x = tuple(x) + if len(x) == 2: + return x + raise ValueError("Expected scalar or length-2 iterable") + return (x, x) + + +class RopePositionEmbedding(nn.Module): + def __init__( + self, + embed_dim: int, + *, + num_heads: int, + base: float | None = 100.0, + min_period: float | None = None, + max_period: float | None = None, + normalize_coords: Literal["min", "max", "separate"] = "separate", + shift_coords: float | None = None, + jitter_coords: float | None = None, + rescale_coords: float | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ): + super().__init__() + assert embed_dim % (4 * num_heads) == 0 + both_periods = min_period is not None and max_period is not None + if (base is None and not both_periods) or (base is not None and both_periods): + raise ValueError( + "Either `base` or `min_period`+`max_period` must be provided." 
+ ) + + D_head = embed_dim // num_heads + self.base = base + self.min_period = min_period + self.max_period = max_period + self.D_head = D_head + self.normalize_coords = normalize_coords + self.shift_coords = shift_coords + self.jitter_coords = jitter_coords + self.rescale_coords = rescale_coords + + # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher + self.dtype = dtype or torch.float32 # Don't rely on self.periods.dtype + self.register_buffer( + "periods", + torch.empty(D_head // 4, device=device, dtype=self.dtype), + persistent=True, + ) + self._init_weights() + + def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]: + device = self.periods.device + dtype = self.dtype + dd = {"device": device, "dtype": dtype} + # Prepare coords in range [-1, +1] + if self.normalize_coords == "max": + max_HW = max(H, W) + coords_h = torch.arange(0.5, H, **dd) / max_HW # [H] + coords_w = torch.arange(0.5, W, **dd) / max_HW # [W] + elif self.normalize_coords == "min": + min_HW = min(H, W) + coords_h = torch.arange(0.5, H, **dd) / min_HW # [H] + coords_w = torch.arange(0.5, W, **dd) / min_HW # [W] + elif self.normalize_coords == "separate": + coords_h = torch.arange(0.5, H, **dd) / H # [H] + coords_w = torch.arange(0.5, W, **dd) / W # [W] + else: + raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") + coords = torch.stack( + torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1 + ) # [H, W, 2] + coords = coords.flatten(0, 1) # [HW, 2] + coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1] + + # Shift coords by adding a uniform value in [-shift, shift] + if self.training and self.shift_coords is not None: + shift_hw = torch.empty(2, **dd).uniform_( + -self.shift_coords, self.shift_coords + ) + coords += shift_hw[None, :] + + # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] + if self.training and self.jitter_coords is not None: + jitter_max = 
np.log(self.jitter_coords) + jitter_min = -jitter_max + jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp() + coords *= jitter_hw[None, :] + + # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale] + if self.training and self.rescale_coords is not None: + rescale_max = np.log(self.rescale_coords) + rescale_min = -rescale_max + rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp() + coords *= rescale_hw + + # Prepare angles and sin/cos + angles = ( + 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] + ) # [HW, 2, D//4] + angles = angles.flatten(1, 2) # [HW, D//2] + angles = angles.tile(2) # [HW, D] + cos = torch.cos(angles) # [HW, D] + sin = torch.sin(angles) # [HW, D] + + return (sin, cos) # 2 * [HW, D] + + def _init_weights(self): + device = self.periods.device + dtype = self.dtype + if self.base is not None: + periods = self.base ** ( + 2 + * torch.arange(self.D_head // 4, device=device, dtype=dtype) + / (self.D_head // 2) + ) # [D//4] + else: + base = self.max_period / self.min_period + exponents = torch.linspace( + 0, 1, self.D_head // 4, device=device, dtype=dtype + ) # [D//4] range [0, 1] + periods = base**exponents # range [1, max_period / min_period] + periods = periods / base # range [min_period / max_period, 1] + periods = periods * self.max_period # range [min_period, max_period] + self.periods.data = periods + + +# ------------------------------------------------------------------------------- +class Tokenizer(nn.Module): + """Stacked window self‑attention that emits one token per window + by re‑using TransformerEncoderLayer blocks.""" + + def __init__( + self, + embed_dims: int, + window_size: int = 4, + num_heads: int = 4, + num_tokenizer_layers: int = 1, + qkv_bias: bool = True, + use_qk_norm: bool = False, + chunk_size: int = 1024, # max windows per chunk + ): + super().__init__() + self.ws = window_size + self.chunk_size = chunk_size + + # local absolute 
positional embeddings for [CLS] + patch tokens + self.local_pos_embed = nn.Parameter( + torch.zeros(1, 1 + window_size * window_size, embed_dims) + ) + trunc_normal_(self.local_pos_embed, std=0.02) + + # build N identical TransformerEncoderLayer blocks + self.blocks = nn.ModuleList( + [ + TransformerEncoderLayer2( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=embed_dims * 4, # standard FFN size + qkv_bias=qkv_bias, + use_qk_norm=use_qk_norm, + ) + for _ in range(num_tokenizer_layers) + ] + ) + + # shared CLS token for pooling + self.w_cls = nn.Parameter(torch.zeros(1, 1, embed_dims)) + trunc_normal_(self.w_cls, std=0.02) + self.gradient_checkpointing = False + + def forward( + self, + x: torch.Tensor, + hw: Tuple[int, int], + ) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Args: + x : B, N, C (N = H*W) + hw : (H, W) before reduction + Returns: + x_ : B, (H/ws)*(W/ws), C + hw_: (H/ws, W/ws) + """ + B, N, C = x.shape + H, W = hw + ws = self.ws + assert H % ws == 0 and W % ws == 0, ( + f"Image size {H}×{W} must be divisible by window {ws}." 
+ ) + + # reshape tokens → non‑overlapping windows + x = x.view(B, H, W, C) + + ph, pw = H // ws, W // ws ## ints in eager mode + ph, pw = int(ph), int(pw) ## ints in scripting mode + x = x.view(B, ph, ws, pw, ws, C) # B, H/ws, ws, W/ws, ws, C + x = x.permute(0, 1, 3, 2, 4, 5) # B, H/ws, W/ws, ws, ws, C + x = x.contiguous().view(B * ph * pw, ws * ws, C) # (B*H/ws*W/ws), ws², C)) + + total_windows = x.size(0) + chunk_size = int(min(self.chunk_size, total_windows)) + token_out = x.new_empty(total_windows, C) + + use_ckpt = torch.is_grad_enabled() and self.gradient_checkpointing + + def _run_blocks(t: torch.Tensor) -> torch.Tensor: + for blk in self.blocks: + t = blk(t) + return t + + for i in range(0, total_windows, chunk_size): + chunk = x[i : i + chunk_size] # (m, ws², C) + m = chunk.size(0) + cls = self.w_cls.expand(m, -1, -1) # (m, 1, C) + chunk = torch.cat([cls, chunk], dim=1) # (m, 1+ws², C) + chunk = chunk + self.local_pos_embed # add local PE + + if use_ckpt: + chunk = checkpoint(_run_blocks, chunk, use_reentrant=False) + else: + chunk = _run_blocks(chunk) + + token_out[i : i + m] = chunk[:, 0] # take CLS out + + token = token_out.view(B, ph * pw, C) # (B, (H/ws)*(W + return token, (ph, pw) + + +# ------------------------------------------------------------------------------- +class GroupedQueryAttention(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + num_kv_heads=None, + input_dims=None, + attn_drop=0.0, + proj_drop=0.0, + qkv_bias=True, + qk_scale=None, + proj_bias=True, + use_qk_norm=True, + v_shortcut=False, + layer_scale_init_value=0.0, + ): + super().__init__() + # Core dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads or num_heads + assert self.num_heads % self.num_kv_heads == 0, ( + "num_kv_heads must divide num_heads" + ) + self.head_dim = embed_dims // num_heads + self.input_dims = input_dims or embed_dims + # Features + self.attn_drop = attn_drop + self.v_shortcut = v_shortcut + 
self.use_qk_norm = use_qk_norm + + # Attention operation selection + if qk_scale is not None: + scale = qk_scale + else: + scale = self.head_dim**-0.5 + + assert qk_scale is None, "qk_scale is not supported" + self.attn_op = F.scaled_dot_product_attention + + # Q/K/V projections + self.wq = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias) + self.wk = nn.Linear( + self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias + ) + self.wv = nn.Linear( + self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias + ) + + if self.use_qk_norm: + self.q_norm = nn.RMSNorm(self.head_dim, eps=1e-6) + self.k_norm = nn.RMSNorm(self.head_dim, eps=1e-6) + + # Output projection + dropout + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + # Optional LayerScale + if layer_scale_init_value > 0: + self.gamma = LayerScale(embed_dims, scale=layer_scale_init_value) + else: + self.gamma = nn.Identity() + + def apply_rope( + self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor] + ) -> Tuple[Tensor, Tensor]: + # All operations will use the dtype of rope, the output is cast back to the dtype of q and k + q_dtype = q.dtype + k_dtype = k.dtype + sin, cos = rope + rope_dtype = sin.dtype + q = q.to(dtype=rope_dtype) + k = k.to(dtype=rope_dtype) + N = q.shape[-2] + prefix = N - sin.shape[-2] ## extra tokens + assert prefix >= 0 + q_prefix = q[:, :, :prefix, :] + q = self._rope_apply(q[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + q = torch.cat((q_prefix, q), dim=-2) # [B, head, N, D//head] + k_prefix = k[:, :, :prefix, :] + k = self._rope_apply(k[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + k = torch.cat((k_prefix, k), dim=-2) # [B, head, N, D//head] + q = q.to(dtype=q_dtype) + k = k.to(dtype=k_dtype) + return q, k + + def _rope_rotate_half(self, x: Tensor) -> Tensor: + # x: [ x0 x1 x2 x3 x4 x5] + # out: [-x3 -x4 -x5 x0 x1 x2] + x1, x2 = x.chunk(2, dim=-1) + return torch.cat([-x2, x1], 
dim=-1) + + def _rope_apply(self, x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: + # x: [..., D], eg [x0, x1, x2, x3, x4, x5] + # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2] + # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2] + return (x * cos) + (self._rope_rotate_half(x) * sin) + + def forward(self, x, rope=None): + B, N, _ = x.shape + # Q: (B, N, num_heads, head_dim) + q = self.wq(x).view(B, N, self.num_heads, self.head_dim) + # K/V: (B, N, num_kv_heads, head_dim) + k = self.wk(x).view(B, N, self.num_kv_heads, self.head_dim) + v = self.wv(x).view(B, N, self.num_kv_heads, self.head_dim) + + # (B, heads, N, head_dim) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + if self.use_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + # Repeat KV heads if group ratio >1 + if self.num_kv_heads != self.num_heads: + factor = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(factor, dim=1) + v = v.repeat_interleave(factor, dim=1) + + if rope is not None: + q, k = self.apply_rope(q, k, rope) + + # Scaled dot-product attention + attn_out = self.attn_op( + q, k, v, dropout_p=self.attn_drop if self.training else 0.0 + ) # (B, num_heads, N, head_dim) + + # Merge heads -> (B, N, embed_dims) + out = attn_out.permute(0, 2, 1, 3).reshape(B, N, self.embed_dims) + + # Output projection + drop + layer scale + out = self.proj(out) + out = self.gamma(self.proj_drop(out)) + + # Optional V-shortcut (only when MQA) + if self.v_shortcut and self.num_kv_heads == 1: + raise NotImplementedError + return out + + +# ------------------------------------------------------------------------------- +class TransformerEncoderLayer2(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + num_kv_heads=None, + feedforward_channels=None, + drop_rate=0.0, + attn_drop_rate=0.0, + layer_scale_init_value=0.0, + use_qk_norm=True, + qkv_bias=True, + ): + super(TransformerEncoderLayer2, self).__init__() + + self.embed_dims = 
embed_dims + self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6) + self.attn = GroupedQueryAttention( + embed_dims=embed_dims, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + qkv_bias=qkv_bias, + layer_scale_init_value=layer_scale_init_value, + use_qk_norm=use_qk_norm, + ) + + self.ln2 = nn.RMSNorm(self.embed_dims, eps=1e-6) + self.ffn = SwiGLUFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def forward(self, x, rope=None): + x = x + self.attn(self.ln1(x), rope=rope) + x = self.ffn(self.ln2(x), identity=x) + return x + + +##----------------------------------- +class Sapiens2(nn.Module): + arch_zoo = { + **dict.fromkeys( + ["sapiens2_0.1b"], + { + "embed_dims": 768, + "num_layers": 12, + "num_heads": 12, + "feedforward_channels": 768 * 4, + "num_tokenizer_layers": 2, + }, + ), + **dict.fromkeys( + ["sapiens2_0.4b"], + { + "embed_dims": 1024, + "num_layers": 24, + "num_heads": 16, + "feedforward_channels": 1024 * 4, + "num_tokenizer_layers": 2, + }, + ), + **dict.fromkeys( + ["sapiens2_0.8b"], + { + "embed_dims": 1280, + "num_layers": 32, + "num_heads": 16, + "feedforward_channels": 1280 * 4, + "num_tokenizer_layers": 3, + }, + ), + **dict.fromkeys( + ["sapiens2_1b"], + { + "embed_dims": 1536, + "num_layers": 40, + "num_heads": 24, + "feedforward_channels": 1536 * 4, + "num_tokenizer_layers": 4, + }, + ), + **dict.fromkeys( + ["sapiens2_5b"], + { + "embed_dims": 2432, + "num_layers": 56, + "num_heads": 32, + "feedforward_channels": 2432 * 4, + "num_tokenizer_layers": 6, + }, + ), + } + + num_extra_tokens = 1 # class token + OUT_TYPES = {"raw", "cls_token", "featmap"} + _supports_gradient_checkpointing = True + + def __init__( + self, + arch="sapiens2_1b", + img_size=(1024, 768), + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0.0, + window_size=4, + 
use_tokenizer=False, ## 4k resolution + use_qk_norm=True, + qkv_bias=True, + final_norm=True, + out_type="raw", + with_cls_token=True, + layer_scale_init_value=1e-4, ## non zero init to activate layerscale + frozen_stages=-1, + patch_cfg=dict(), + layer_cfgs=dict(), + pos_embed_rope_base: float = 100.0, + pos_embed_rope_min_period: float | None = None, + pos_embed_rope_max_period: float | None = None, + pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate", + pos_embed_rope_shift_coords: float | None = None, + pos_embed_rope_jitter_coords: float | None = None, + pos_embed_rope_rescale_coords: float | None = None, + pos_embed_rope_dtype: str = "bf16", + n_storage_tokens: int = 8, + ): + super().__init__() + + arch = arch.lower() + assert arch in set(self.arch_zoo), ( + f"Arch {arch} is not in default archs {set(self.arch_zoo)}" + ) + self.arch_settings = self.arch_zoo[arch] + + self.embed_dims = self.arch_settings["embed_dims"] + self.num_layers = self.arch_settings["num_layers"] + self.patch_size = patch_size + + self.window_size = window_size + img_size = to_2tuple(img_size) + encoder_img_size = ( + (img_size[0] // window_size, img_size[1] // window_size) + if use_tokenizer + else img_size + ) + self.img_size = to_2tuple(encoder_img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=self.img_size, + embed_dims=self.embed_dims, + kernel_size=patch_size, + stride=patch_size, + bias=True, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + self.rope_embed = RopePositionEmbedding( + embed_dim=self.embed_dims, + num_heads=self.arch_settings["num_heads"], + base=pos_embed_rope_base, + min_period=pos_embed_rope_min_period, + max_period=pos_embed_rope_max_period, + normalize_coords=pos_embed_rope_normalize_coords, + 
shift_coords=pos_embed_rope_shift_coords, + jitter_coords=pos_embed_rope_jitter_coords, + rescale_coords=pos_embed_rope_rescale_coords, + dtype=torch.bfloat16 if pos_embed_rope_dtype == "bf16" else torch.float32, + ) + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError( + f"Unsupported `out_type` {out_type}, please " + f"choose from {self.OUT_TYPES}" + ) + self.out_type = out_type + + if use_tokenizer == True: + self.tokenizer = Tokenizer( + embed_dims=self.embed_dims, + window_size=self.window_size, + num_heads=self.arch_settings["num_heads"], + num_tokenizer_layers=self.arch_settings["num_tokenizer_layers"], + qkv_bias=True, + use_qk_norm=False, + ) + else: + self.tokenizer = None + + # Set cls + storage tokens + self.with_cls_token = with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != "cls_token": + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError('with_cls_token must be True when `out_type="cls_token"`.') + + ## registers + self.n_storage_tokens = int(n_storage_tokens) + self.storage_tokens = ( + nn.Parameter(torch.zeros(1, self.n_storage_tokens, self.embed_dims)) + if self.n_storage_tokens > 0 + else None + ) + # how many non-patch tokens are at the front + self.num_extra_tokens = ( + 1 if self.cls_token is not None else 0 + ) + self.n_storage_tokens + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), ( + f'"out_indices" must by a sequence or int, get {type(out_indices)} instead.' 
+ ) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, ( + f"Invalid out_indices {index}" + ) + self.out_indices = out_indices + + self.blocks = nn.Sequential() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + + mhsa_early, mhsa_late = 8, 8 + for i in range(self.num_layers): + if i < mhsa_early or i >= self.num_layers - mhsa_late: + num_kv_heads = None ## use MHSA + else: + num_kv_heads = self.arch_settings["num_heads"] // 2 # Use GQA + + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings["num_heads"], + num_kv_heads=num_kv_heads, + feedforward_channels=self.arch_settings["feedforward_channels"], + use_qk_norm=use_qk_norm, + layer_scale_init_value=layer_scale_init_value, + drop_rate=drop_rate, + qkv_bias=qkv_bias, + ) + _layer_cfg.update(layer_cfgs[i]) + self.blocks.append(TransformerEncoderLayer2(**_layer_cfg)) + + self.frozen_stages = frozen_stages + + self.final_norm = final_norm + if final_norm: + self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + ## load init weights + self.init_weights() + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self, enable=True): + self.gradient_checkpointing = enable + if self.tokenizer is not None: + self.tokenizer.gradient_checkpointing = enable + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def init_weights(self): + # Initialize class token and storagr token embeddings + if self.with_cls_token: + trunc_normal_(self.cls_token, std=0.02) + + if self.storage_tokens is not None: + trunc_normal_(self.storage_tokens, std=0.02) + + # Apply custom initialization to all submodules + self.apply(self._init_weights) + + def _init_weights(self, m): 
+ if isinstance(m, nn.Linear): + # Use a truncated normal distribution for linear layer weights + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, (nn.LayerNorm, nn.RMSNorm)): + # Initialize normalization layers to act as an identity function + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias, 0) + if hasattr(m, "weight") and m.weight is not None: + nn.init.constant_(m.weight, 1.0) + + elif isinstance(m, nn.Conv2d): + # Initialize conv layer weights like linear layers + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _freeze_stages(self): + ## freeze tokenizer + if self.frozen_stages >= 1 and self.tokenizer is not None: + self.tokenizer.eval() + for param in self.tokenizer.parameters(): + param.requires_grad = False + + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + self.cls_token.requires_grad = False + if self.storage_tokens is not None: + self.storage_tokens.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze the last layer norm + if self.frozen_stages == len(self.blocks): + if self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + + x, patch_resolution = self.patch_embed(x) # (B, 256*256, C) + if self.tokenizer is not None: + x, patch_resolution = self.tokenizer(x, patch_resolution) + + # prepend [CLS] and storage tokens + prepend = [] + if self.cls_token is not None: + prepend.append(self.cls_token.expand(B, -1, -1)) + if self.storage_tokens is not None: + prepend.append(self.storage_tokens.expand(B, -1, -1)) + if len(prepend) > 0: + x = torch.cat(prepend + 
[x], dim=1) + + rope_sincos = self.rope_embed(H=patch_resolution[0], W=patch_resolution[1]) + outs = [] + for i, layer in enumerate(self.blocks): + if self.gradient_checkpointing and self.training: + x = checkpoint(layer, x, rope_sincos, use_reentrant=False) + else: + x = layer(x, rope=rope_sincos) + + if i == len(self.blocks) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == "raw": + return x + if self.out_type == "cls_token": + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens :] + if self.out_type == "featmap": + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + + @property + def norm1(self): + return self.ln1 + + +# ---------------------------------------------------------------------------- +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + inplace: bool = False, + data_format: str = "channels_last", + scale: float = 1e-5, + ): + super().__init__() + assert data_format in ( + "channels_last", + "channels_first", + ), "'data_format' could only be channels_last or channels_first." 
# ----------------------------------------------------------------------------
class PatchEmbed(nn.Module):
    """Convolutional patch embedding.

    A single ``nn.Conv2d`` maps the image to ``embed_dims`` channels and the
    spatial grid is flattened into a token sequence. The ``padding`` argument
    is accepted but not used: the convolution always runs with zero padding.
    When ``input_size`` is given, the resulting token grid size is
    precomputed and stored in ``init_out_size``.
    """

    def __init__(
        self,
        in_channels=3,
        embed_dims=768,
        kernel_size=16,
        stride=16,
        padding="corner",
        dilation=1,
        bias=True,
        input_size=None,
    ):
        super().__init__()

        self.embed_dims = embed_dims

        # Normalise every geometric argument to an (h, w) pair; a missing
        # stride falls back to the kernel size.
        kernel_hw = to_2tuple(kernel_size)
        stride_hw = to_2tuple(kernel_size if stride is None else stride)
        dilation_hw = to_2tuple(dilation)
        pad_hw = to_2tuple(0)

        self.projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_hw,
            stride=stride_hw,
            padding=pad_hw,
            dilation=dilation_hw,
            bias=bias,
        )

        if not input_size:
            self.init_input_size = None
            self.init_out_size = None
            return

        in_hw = to_2tuple(input_size)
        self.init_input_size = in_hw
        # Standard convolution output-size formula, evaluated per axis.
        self.init_out_size = tuple(
            (
                in_hw[axis]
                + 2 * pad_hw[axis]
                - dilation_hw[axis] * (kernel_hw[axis] - 1)
                - 1
            )
            // stride_hw[axis]
            + 1
            for axis in (0, 1)
        )

    def forward(self, x):
        """Embed ``x`` and return ``(tokens, (H_out, W_out))``.

        ``tokens`` has shape ``(B, H_out * W_out, embed_dims)``.
        """
        feat = self.projection(x)
        grid_hw = (feat.shape[2], feat.shape[3])
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, grid_hw
+ https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py + """ # noqa + + def __init__( + self, + embed_dims: int, + feedforward_channels: Optional[int] = None, + out_dims: Optional[int] = None, + layer_scale_init_value: float = 0.0, + bias: bool = True, + add_identity: bool = True, + ) -> None: + super().__init__() + self.embed_dims = embed_dims + self.out_dims = out_dims or embed_dims + hidden_dims = feedforward_channels or embed_dims + + self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias) + self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias) + + if layer_scale_init_value > 0: + self.gamma2 = LayerScale(dim=embed_dims, scale=layer_scale_init_value) + else: + self.gamma2 = nn.Identity() + + self.add_identity = add_identity + + def forward( + self, x: torch.Tensor, identity: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + out = self.w3(hidden) + out = self.gamma2(out) + + if self.out_dims != self.embed_dims or not self.add_identity: + # due to the dimension inconsistence or user setting + # not to apply residual operation + return out + + if identity is None: + identity = x + return identity + out + + +# ---------------------------------------------------------------------------- +_IMAGENET_MEAN = (0.485, 0.456, 0.406) +_IMAGENET_STD = (0.229, 0.224, 0.225) + + +def imagenet_normalize(tensors_0_1: torch.Tensor) -> torch.Tensor: + """Apply ImageNet normalization to a (B, C, H, W) RGB tensor in [0, 1].""" + mean = torch.as_tensor( + _IMAGENET_MEAN, dtype=tensors_0_1.dtype, device=tensors_0_1.device + ).view(1, 3, 1, 1) + std = torch.as_tensor( + _IMAGENET_STD, dtype=tensors_0_1.dtype, device=tensors_0_1.device + ).view(1, 3, 1, 1) + return (tensors_0_1 - mean) / std + + +# ---------------------------------------------------------------------------- +class MattingHead(nn.Module): + """Matting decode head from + 
class Sapiens2Matting(nn.Module):
    """Sapiens2 backbone followed by a MattingHead for human image matting.

    Reference: https://github.com/facebookresearch/sapiens2/blob/main/docs/MATTING.md
    """

    # Channel width handed to the decode head for every supported backbone.
    _ARCH_TO_EMBED_DIM = {
        "sapiens2_0.1b": 768,
        "sapiens2_0.4b": 1024,
        "sapiens2_0.8b": 1280,
        "sapiens2_1b": 1536,
        "sapiens2_5b": 2432,
    }

    def __init__(
        self,
        arch: str = "sapiens2_1b",
        img_size: Tuple[int, int] = (1024, 768),
        patch_size: int = 16,
    ):
        """Build the backbone and the matting head.

        Raises:
            ValueError: if the lower-cased ``arch`` is not a known architecture.
        """
        super().__init__()
        arch = arch.lower()
        embed_dim = self._ARCH_TO_EMBED_DIM.get(arch)
        if embed_dim is None:
            raise ValueError(f"Unsupported arch {arch}")

        self.arch = arch
        self.img_size = to_2tuple(img_size)
        self.patch_size = patch_size

        # Backbone is created before the head, as in the reference layout.
        self.backbone = Sapiens2(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            final_norm=True,
            use_tokenizer=False,
            with_cls_token=True,
            out_type="featmap",
        )
        self.decode_head = MattingHead(
            in_channels=embed_dim,
            upsample_channels=(768, 512, 256, 128),
            conv_out_channels=(64, 32, 16),
            conv_kernel_sizes=(3, 3, 3),
            out_channels=4,
        )

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "facebook/sapiens2-matting-1b",
        filename: str = "sapiens2_1b_matting.safetensors",
        arch: str = "sapiens2_1b",
        img_size: Tuple[int, int] = (1024, 768),
        patch_size: int = 16,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Sapiens2Matting":
        """Create a model in eval mode and load safetensors weights into it.

        Weights are looked up under ``MODELS_PATH/sapiens2/<filename>`` and are
        fetched from the HuggingFace Hub into that folder when missing. The
        model is moved only when ``device`` or ``dtype`` is given.
        """
        import huggingface_hub
        from safetensors.torch import load_file

        weights_dir = os.path.join(MODELS_PATH, "sapiens2")
        safetensors_path = os.path.join(MODELS_PATH, "sapiens2", filename)
        if not os.path.exists(safetensors_path):
            print(f"Downloading pretrained weights from HuggingFace Hub: {repo_id}/{filename}...")
            os.makedirs(os.path.dirname(safetensors_path), exist_ok=True)
            safetensors_path = huggingface_hub.hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=weights_dir,
            )

        model = cls(arch=arch, img_size=img_size, patch_size=patch_size)
        model.load_state_dict(load_file(safetensors_path))
        model.eval()
        if not (device is None and dtype is None):
            model.to(device=device, dtype=dtype)
        return model

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    @torch.no_grad()
    def forward(self, image, max_res: int = 1024):
        """Take a PIL image and return a PIL alpha-matte mask in RGB mode at
        the original input size.

        The image keeps its aspect ratio, is scaled down when it holds more
        than ``max_res * max_res`` pixels, and each side is snapped down to a
        multiple of ``patch_size`` (never below one patch).
        """
        from torchvision import transforms

        patch = self.patch_size
        src_w, src_h = image.size
        run_h, run_w = src_h, src_w
        pixel_budget = max_res * max_res
        if run_h * run_w > pixel_budget:
            shrink = math.sqrt(pixel_budget / (run_h * run_w))
            run_h = int(run_h * shrink)
            run_w = int(run_w * shrink)
        run_h = max(patch, (run_h // patch) * patch)
        run_w = max(patch, (run_w // patch) * patch)

        preprocess = transforms.Compose(
            [
                transforms.Resize((run_h, run_w)),
                transforms.ToTensor(),
                transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
            ]
        )
        batch = preprocess(image).unsqueeze(0)
        batch = batch.to(self.device, dtype=self.dtype)

        features = self.backbone(batch)[0]
        prediction = self.decode_head(features)  # (1, 4, H, W) in [0, 1]
        # Channel 3 of the first (only) sample is the alpha matte.
        alpha = prediction[0, 3].float().cpu()

        mask = transforms.ToPILImage()(alpha)
        return mask.resize(image.size).convert("RGB")
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor built on ``F.scaled_dot_product_attention`` (PyTorch 2.0+).

    ``hidden_size`` and ``cross_attention_dim`` are accepted by the constructor
    but not stored or used.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run one attention layer using the projections owned by ``attn``.

        Without ``encoder_hidden_states`` this is self-attention; otherwise the
        keys and values come from ``encoder_hidden_states``. A 4D input is
        flattened to a sequence and restored to its 4D shape before returning.
        """
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch / sequence length follow the context when one is supplied.
        context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        def _split_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = _split_heads(query)
        key = _split_heads(key)
        value = _split_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attended = attended.to(query.dtype)

        # to_out[0] is the linear projection, to_out[1] the dropout
        projected = attn.to_out[1](attn.to_out[0](attended))

        if is_spatial:
            projected = projected.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            projected = projected + skip

        return projected / attn.rescale_output_factor
+ adapter + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None, + adapter_hidden_size=None, has_bias=False, **kwargs): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.adapter_ref: weakref.ref = weakref.ref(adapter) + + self.hidden_size = hidden_size + self.adapter_hidden_size = adapter_hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + + self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) + self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) + + @property + def is_active(self): + return self.adapter_ref().is_active + # return False + + @property + def unconditional_embeds(self): + return self.adapter_ref().adapter_ref().unconditional_embeds + + @property + def conditional_embeds(self): + return self.adapter_ref().adapter_ref().conditional_embeds + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + is_active = self.adapter_ref().is_active + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if 
attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + # will be none if disabled + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # only use one TE or the other. 
If our adapter is active only use ours + if self.is_active and self.conditional_embeds is not None: + + adapter_hidden_states = self.conditional_embeds + if adapter_hidden_states.shape[0] < batch_size: + # doing cfg + adapter_hidden_states = torch.cat([ + self.unconditional_embeds, + adapter_hidden_states + ], dim=0) + # needs to be shape (batch, 1, 1) + if len(adapter_hidden_states.shape) == 2: + adapter_hidden_states = adapter_hidden_states.unsqueeze(1) + # conditional_batch_size = adapter_hidden_states.shape[0] + # conditional_query = query + + # for ip-adapter + vd_key = self.to_k_adapter(adapter_hidden_states) + vd_value = self.to_v_adapter(adapter_hidden_states) + + vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + vd_hidden_states = F.scaled_dot_product_attention( + query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + vd_hidden_states = vd_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * vd_hidden_states + + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class SingleValueAdapter(torch.nn.Module): + def __init__( + self, + adapter: 'CustomAdapter', + sd: 'StableDiffusion', + num_values: int = 1, + ): + super(SingleValueAdapter, self).__init__() + is_pixart = sd.is_pixart + self.adapter_ref: weakref.ref = 
weakref.ref(adapter) + self.sd_ref: weakref.ref = weakref.ref(sd) + self.token_size = num_values + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + + attn_processor_keys = [] + if is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + + else: + attn_processor_keys = list(sd.unet.attn_processors.keys()) + + for name in attn_processor_keys: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer"): + hidden_size = sd.unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor2_0() + else: + layer_name = name.split(".processor")[0] + to_k_adapter = unet_sd[layer_name + ".to_k.weight"] + to_v_adapter = unet_sd[layer_name + ".to_v.weight"] + # if is_pixart: + # to_k_bias = unet_sd[layer_name + ".to_k.bias"] + # to_v_bias = unet_sd[layer_name + ".to_v.bias"] + # else: + # to_k_bias = None + # to_v_bias = None + + # add zero padding to the adapter + if to_k_adapter.shape[1] < self.token_size: + to_k_adapter = torch.cat([ + to_k_adapter, + torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to( + to_k_adapter.device, 
dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + to_v_adapter = torch.cat([ + to_v_adapter, + torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + # if is_pixart: + # to_k_bias = torch.cat([ + # to_k_bias, + # torch.zeros(self.token_size - to_k_adapter.shape[1]).to( + # to_k_adapter.device, dtype=to_k_adapter.dtype) + # ], + # dim=0 + # ) + # to_v_bias = torch.cat([ + # to_v_bias, + # torch.zeros(self.token_size - to_v_adapter.shape[1]).to( + # to_k_adapter.device, dtype=to_k_adapter.dtype) + # ], + # dim=0 + # ) + elif to_k_adapter.shape[1] > self.token_size: + to_k_adapter = to_k_adapter[:, :self.token_size] + to_v_adapter = to_v_adapter[:, :self.token_size] + # if is_pixart: + # to_k_bias = to_k_bias[:self.token_size] + # to_v_bias = to_v_bias[:self.token_size] + else: + to_k_adapter = to_k_adapter + to_v_adapter = to_v_adapter + # if is_pixart: + # to_k_bias = to_k_bias + # to_v_bias = to_v_bias + + weights = { + "to_k_adapter.weight": to_k_adapter * 0.01, + "to_v_adapter.weight": to_v_adapter * 0.01, + } + # if is_pixart: + # weights["to_k_adapter.bias"] = to_k_bias + # weights["to_v_adapter.bias"] = to_v_bias + + attn_procs[name] = SingleValueAdapterAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + adapter=self, + adapter_hidden_size=self.token_size, + has_bias=False, + ) + attn_procs[name].load_state_dict(weights) + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList([ + transformer.transformer_blocks[i].attn1.processor for i in range(len(transformer.transformer_blocks)) + ] + [ + 
class SAFEReducerBlock(nn.Module):
    """
    Halves the width and height of a feature map while keeping its channel count.

    The block is meant to be applied repeatedly: every pass shrinks the map
    and adds an average-pooled copy of its input as a residual, so information
    from the larger map is carried along.
    """

    def __init__(self, channels=512):
        super(SAFEReducerBlock, self).__init__()
        self.channels = channels

        # Two conv -> GELU -> batch-norm stages, then a 2x average pool.
        stages = []
        for _ in range(2):
            stages.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
            stages.append(nn.GELU())
            stages.append(nn.BatchNorm2d(channels))
        stages.append(nn.AvgPool2d(kernel_size=2, stride=2))
        self.reducer = nn.Sequential(*stages)

        # Shrinks the untouched input so it lines up with the reduced map.
        self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return the reduced map plus the pooled input."""
        return self.reducer(x) + self.residual_shrink(x)
class SAFEImageProcessor(BaseImageProcessor):
    """Tensor pre-processor that resizes and pads image batches.

    Images are brought into the ``min_size`` / ``max_size`` window with
    bilinear interpolation and then reflect-padded so that height and width
    are multiples of 16.
    """

    def __init__(
            self,
            max_size=1024,
            min_size=256,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.max_size = max_size
        self.min_size = min_size

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: Union[str, os.PathLike],
            cache_dir: Optional[Union[str, os.PathLike]] = None,
            force_download: bool = False,
            local_files_only: bool = False,
            token: Optional[Union[str, bool]] = None,
            revision: str = "main",
            **kwargs,
    ):
        # Nothing is loaded; only the extra keyword arguments reach the constructor.
        return cls(**kwargs)

    def __call__(
            self,
            images,
            **kwargs
    ):
        """Validate, resize and pad ``images``; returns a ``SAFEIPReturn``.

        Raises:
            ValueError: if any value lies below -0.3 or above 1.3.
        """
        # TODO allow for random resizing
        # Input is expected in the 0 - 1 range; a small tolerance is allowed.
        lowest, highest = images.min(), images.max()
        if lowest < -0.3 or highest > 1.3:
            raise ValueError(
                "images fed into SAFEImageProcessor values must be between 0 and 1. Got min: {}, max: {}".format(
                    lowest, highest
                ))

        # make sure we have (bs, 3, h, w)
        while len(images.shape) < 4:
            images = images.unsqueeze(0)

        # a single channel is repeated three times
        if images.shape[1] == 1:
            images = torch.cat([images] * 3, dim=1)

        height, width = images.shape[2], images.shape[3]
        lower, upper = self.min_size, self.max_size

        target_hw = None
        if width < lower or height < lower:
            # scale up so that the smallest side becomes min_size
            if width < height:
                target_hw = (int(height * (lower / width)), lower)
            else:
                target_hw = (lower, int(width * (lower / height)))
        elif width > upper or height > upper:
            # scale down so that the largest side becomes max_size,
            # but do not shrink the other side below min_size
            if width > height:
                new_width = upper
                new_height = int(height * (upper / width))
            else:
                new_height = upper
                new_width = int(width * (upper / height))

            if new_width < lower:
                new_width = lower
                new_height = int(height * (lower / width))

            if new_height < lower:
                new_height = lower
                new_width = int(width * (lower / height))

            target_hw = (new_height, new_width)

        if target_hw is not None:
            images = nn.functional.interpolate(images, size=target_hw, mode='bilinear',
                                               align_corners=False)

        # if either side is not divisible by 16, mirror pad it (height first, then width)
        for dim in (2, 3):
            remainder = images.shape[dim] % 16
            if remainder == 0:
                continue
            total = 16 - remainder
            before = total // 2
            after = total - before
            # F.pad lists the last dimension first: (left, right, top, bottom)
            spec = (0, 0, before, after) if dim == 2 else (before, after, 0, 0)
            images = nn.functional.pad(images, spec, mode='reflect')

        return SAFEIPReturn(images)
class SAFEVisionModel(SizeAgnosticFeatureEncoder):
    """``SizeAgnosticFeatureEncoder`` that also carries a ``config`` object and
    wraps its forward result in a ``SAFEVMReturn``.

    All keyword arguments are forwarded both to ``SAFEVMConfig`` and to the
    encoder constructor.
    """

    def __init__(self, **kwargs):
        # Initialise the nn.Module machinery first so attribute assignment does
        # not depend on how Module.__setattr__ behaves before __init__ has run.
        super(SAFEVisionModel, self).__init__(**kwargs)
        self.config = SAFEVMConfig(**kwargs)
        self.image_size = None

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build a freshly initialised model; no weights are loaded.

        Positional arguments are ignored. The instance is created through
        ``cls`` so that subclasses get an instance of their own type.
        """
        return cls(**kwargs)

    def forward(self, x, **kwargs):
        """Encode ``x`` and wrap the result; extra keyword arguments are ignored."""
        return SAFEVMReturn(super().forward(x))
class InOutModule(torch.nn.Module):
    """Replacement input embedder / output projection pair for a transformer.

    ``from_model`` reroutes ``model.x_embedder.forward`` to ``in_forward`` and
    ``model.proj_out.forward`` to ``out_forward``. While the owning adapter is
    active the layers of this module are used; otherwise each call is handed
    back to the original forward saved on the hijacked layer as
    ``_orig_ctrl_lora_forward``.

    Args:
        adapter: Owning adapter, held through a weak reference. It must expose
            ``is_active`` and ``control_lora``.
        orig_layer: The model's original input embedder (weak reference).
        in_channels: Feature size fed to ``x_embedder`` / produced by ``proj_out``.
        out_channels: Feature size produced by ``x_embedder`` / fed to ``proj_out``.
        orig_out_layer: The model's original output projection (weak
            reference). When omitted, ``out_forward`` falls back to
            ``orig_layer`` as earlier versions did.
    """

    def __init__(
        self,
        adapter: 'SubpixelAdapter',
        orig_layer: torch.nn.Linear,
        in_channels=64,
        out_channels=3072,
        orig_out_layer=None,
    ):
        super().__init__()
        # only do the weight for the new input. We combine with the original linear layer
        self.x_embedder = torch.nn.Linear(
            in_channels,
            out_channels,
            bias=True,
        )

        self.proj_out = torch.nn.Linear(
            out_channels,
            in_channels,
            bias=True,
        )
        # make sure the weight is float32
        self.x_embedder.weight.data = self.x_embedder.weight.data.float()
        self.x_embedder.bias.data = self.x_embedder.bias.data.float()

        self.proj_out.weight.data = self.proj_out.weight.data.float()
        self.proj_out.bias.data = self.proj_out.bias.data.float()

        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
        # The output projection needs its own reference: its saved original
        # forward lives on proj_out, not on the input embedder.
        self.orig_out_layer_ref = (
            weakref.ref(orig_out_layer) if orig_out_layer is not None else None
        )

    @classmethod
    def from_model(
        cls,
        model: 'FluxTransformer2DModel',
        adapter: 'SubpixelAdapter',
        num_channels: int = 768,
        downscale_factor: int = 8
    ):
        """Create the module for ``model`` and hijack its in/out projections.

        Also rewrites the in/out channel entries of ``model.config`` and
        replaces the VAE of ``adapter.sd_ref()`` (and of its pipeline) with an
        ``AutoencoderPixelMixer``.

        Raises:
            ValueError: if ``model`` is not a ``FluxTransformer2DModel``.
        """
        if model.__class__.__name__ == 'FluxTransformer2DModel':

            x_embedder: torch.nn.Linear = model.x_embedder
            proj_out: torch.nn.Linear = model.proj_out
            in_out_module = cls(
                adapter,
                orig_layer=x_embedder,
                in_channels=num_channels,
                out_channels=x_embedder.out_features,
                orig_out_layer=proj_out,
            )

            # hijack the forward method
            x_embedder._orig_ctrl_lora_forward = x_embedder.forward
            x_embedder.forward = in_out_module.in_forward
            proj_out._orig_ctrl_lora_forward = proj_out.forward
            proj_out.forward = in_out_module.out_forward

            # update the config of the transformer
            model.config.in_channels = num_channels
            model.config["in_channels"] = num_channels
            model.config.out_channels = num_channels
            model.config["out_channels"] = num_channels

            # if the shape matches, copy the weights
            if x_embedder.weight.shape == in_out_module.x_embedder.weight.shape:
                in_out_module.x_embedder.weight.data = x_embedder.weight.data.clone().float()
                in_out_module.x_embedder.bias.data = x_embedder.bias.data.clone().float()
                in_out_module.proj_out.weight.data = proj_out.weight.data.clone().float()
                in_out_module.proj_out.bias.data = proj_out.bias.data.clone().float()

            # replace the vae of the model
            sd = adapter.sd_ref()
            sd.vae = AutoencoderPixelMixer(
                in_channels=3,
                downscale_factor=downscale_factor
            )

            sd.pipeline.vae = sd.vae

            return in_out_module
        else:
            raise ValueError("Model not supported")

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def _sync_control_lora(self, active):
        """Switch the adapter's control LoRA (if it has one) on or off."""
        adapter = self.adapter_ref()
        if adapter.control_lora is not None:
            adapter.control_lora.is_active = active

    @staticmethod
    def _run_layer(layer, x):
        """Run ``layer`` in its own device/dtype, then cast the result back to ``x``'s."""
        orig_device = x.device
        orig_dtype = x.dtype
        x = x.to(layer.weight.device, dtype=layer.weight.dtype)
        x = layer(x)
        return x.to(orig_device, dtype=orig_dtype)

    def in_forward(self, x):
        """Forward used in place of the model's ``x_embedder.forward``."""
        if not self.is_active:
            # make sure lora is not active
            self._sync_control_lora(False)
            return self.orig_layer_ref()._orig_ctrl_lora_forward(x)

        # make sure lora is active
        self._sync_control_lora(True)
        return self._run_layer(self.x_embedder, x)

    def out_forward(self, x):
        """Forward used in place of the model's ``proj_out.forward``."""
        if not self.is_active:
            # make sure lora is not active
            self._sync_control_lora(False)
            # Use the original *output* projection. Going through
            # orig_layer_ref here would run the input embedder's forward.
            layer_ref = self.orig_out_layer_ref
            if layer_ref is None:
                layer_ref = self.orig_layer_ref
            return layer_ref()._orig_ctrl_lora_forward(x)

        # make sure lora is active
        self._sync_control_lora(True)
        return self._run_layer(self.proj_out, x)
+ self.network_config = config.lora_config + self.train_config = train_config + self.device_torch = sd.device_torch + self.control_lora = None + + if self.network_config is not None: + + network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs + if hasattr(sd, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = sd.target_lora_modules + + if 'ignore_if_contains' not in network_kwargs: + network_kwargs['ignore_if_contains'] = [] + + # always ignore x_embedder + network_kwargs['ignore_if_contains'].append('transformer.x_embedder') + network_kwargs['ignore_if_contains'].append('transformer.proj_out') + + self.control_lora = LoRASpecialNetwork( + text_encoder=sd.text_encoder, + unet=sd.unet, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=False, + is_lorm=False, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=sd.is_transformer, + base_model=sd, + **network_kwargs + ) + self.control_lora.force_to(self.device_torch, dtype=torch.float32) + self.control_lora._update_torch_multiplier() + self.control_lora.apply_to( + sd.text_encoder, + sd.unet, + 
self.train_config.train_text_encoder, + self.train_config.train_unet + ) + self.control_lora.can_merge_in = False + self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet) + if self.train_config.gradient_checkpointing: + self.control_lora.enable_gradient_checkpointing() + + downscale_factor = config.subpixel_downscale_factor + if downscale_factor == 8: + num_channels = 768 + elif downscale_factor == 16: + num_channels = 3072 + else: + raise ValueError( + f"downscale_factor {downscale_factor} not supported" + ) + + self.in_out: InOutModule = InOutModule.from_model( + sd.unet_unwrapped, + self, + num_channels=num_channels, # packed channels + downscale_factor=downscale_factor + ) + self.in_out.to(self.device_torch) + + def get_params(self): + if self.control_lora is not None: + config = { + 'text_encoder_lr': self.train_config.lr, + 'unet_lr': self.train_config.lr, + } + sig = inspect.signature(self.control_lora.prepare_optimizer_params) + if 'default_lr' in sig.parameters: + config['default_lr'] = self.train_config.lr + if 'learning_rate' in sig.parameters: + config['learning_rate'] = self.train_config.lr + params_net = self.control_lora.prepare_optimizer_params( + **config + ) + + # we want only tensors here + params = [] + for p in params_net: + if isinstance(p, dict): + params += p["params"] + elif isinstance(p, torch.Tensor): + params.append(p) + elif isinstance(p, list): + params += p + else: + params = [] + + # make sure the embedder is float32 + self.in_out.to(torch.float32) + + params += list(self.in_out.parameters()) + + # we need to be able to yield from the list like yield from params + + return params + + def load_weights(self, state_dict, strict=True): + lora_sd = {} + img_embedder_sd = {} + for key, value in state_dict.items(): + if "transformer.x_embedder" in key: + new_key = key.replace("transformer.", "") + img_embedder_sd[new_key] = value + elif "transformer.proj_out" in key: + new_key = key.replace("transformer.", "") + 
img_embedder_sd[new_key] = value + else: + lora_sd[key] = value + + # todo process state dict before loading + if self.control_lora is not None: + self.control_lora.load_weights(lora_sd) + # automatically upgrade the x imbedder if more dims are added + self.in_out.load_state_dict(img_embedder_sd, strict=False) + + def get_state_dict(self): + if self.control_lora is not None: + lora_sd = self.control_lora.get_state_dict(dtype=torch.float32) + else: + lora_sd = {} + # todo make sure we match loras elseware. + img_embedder_sd = self.in_out.state_dict() + for key, value in img_embedder_sd.items(): + lora_sd[f"transformer.{key}"] = value + return lora_sd + + @property + def is_active(self): + return self.adapter_ref().is_active diff --git a/ai-toolkit/toolkit/models/te_adapter.py b/ai-toolkit/toolkit/models/te_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..37396dd792e997b29bb53f2dd89a7c75ced81194 --- /dev/null +++ b/ai-toolkit/toolkit/models/te_adapter.py @@ -0,0 +1,452 @@ +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import weakref +from typing import Union, TYPE_CHECKING +from transformers import T5EncoderModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection +from toolkit import train_tools +from toolkit.prompt_utils import PromptEmbeds +from diffusers import Transformer2DModel +from toolkit.util.ip_adapter_utils import AttnProcessor2_0 + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.custom_adapter import CustomAdapter + + +class TEAdapterCaptionProjection(nn.Module): + def __init__(self, caption_channels, adapter: 'TEAdapter'): + super().__init__() + in_features = caption_channels + self.adapter_ref: weakref.ref = weakref.ref(adapter) + sd = adapter.sd_ref() + self.parent_module_ref = weakref.ref(sd.unet.caption_projection) + parent_module = self.parent_module_ref() + self.linear_1 = nn.Linear( + in_features=in_features, + 
class TEAdapterAttnProcessor(nn.Module):
    r"""
    Attention processor for Custom TE for PyTorch 2.0.

    When the owning adapter is active and conditional embeds are available, the
    keys and values are projected from the adapter's text embeds through
    ``to_k_adapter`` / ``to_v_adapter``; otherwise the attention layer's own
    ``to_k`` / ``to_v`` are applied to ``encoder_hidden_states``.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer (output size of the adapter
            key/value projections).
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            Stored on the processor; not read by ``__call__``.
        num_tokens (`int`, defaults to 4):
            Stored on the processor; not read by ``__call__``.
        adapter:
            The owning TE adapter. Held through a weak reference, so it must be
            a weak-referenceable object (``None`` is not accepted by
            ``weakref.ref``).
        adapter_hidden_size (`int`):
            Feature size of the adapter text embeds (input size of the adapter
            key/value projections).
        layer_name (`str`, optional):
            Name of the attention layer this processor is attached to.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None,
                 adapter_hidden_size=None, layer_name=None):
        super().__init__()
        self.layer_name = layer_name

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.adapter_ref: weakref.ref = weakref.ref(adapter)

        self.hidden_size = hidden_size
        self.adapter_hidden_size = adapter_hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False)
        self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False)

    @property
    def is_active(self):
        """Whether the owning TE adapter is currently active."""
        return self.adapter_ref().is_active

    @property
    def unconditional_embeds(self):
        """Unconditional embeds stored on the adapter's parent adapter."""
        return self.adapter_ref().adapter_ref().unconditional_embeds

    @property
    def conditional_embeds(self):
        """Conditional embeds stored on the adapter's parent adapter."""
        return self.adapter_ref().adapter_ref().conditional_embeds

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run attention for ``attn``, optionally swapping in adapter keys/values.

        Returns the attended hidden states with the same rank as the input
        (4D inputs are flattened for attention and restored afterwards).

        Raises:
            RuntimeError: If the key/value tensors cannot be reshaped into
                heads; the message reports their shapes before the reshape and
                the original error is chained as ``__cause__``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # will be none if disabled
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # only use one TE or the other. If our adapter is active only use ours
        if self.is_active and self.conditional_embeds is not None:
            adapter_hidden_states = self.conditional_embeds.text_embeds
            # check if we are doing unconditional
            if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != encoder_hidden_states.shape[0]:
                # concat unconditional to match the hidden state batch size
                if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
                    unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
                else:
                    unconditional = self.unconditional_embeds.text_embeds
                adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
            # NOTE(review): attention_mask above was prepared for the length of
            # encoder_hidden_states, not adapter_hidden_states; confirm they
            # always match when the text encoder arch is not "clip".
            key = self.to_k_adapter(adapter_hidden_states)
            value = self.to_v_adapter(adapter_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Remember the pre-reshape shapes so a failure reports what actually
        # arrived, even if only the value reshape fails.
        key_shape, value_shape = key.shape, value.shape
        try:
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        except RuntimeError as err:
            raise RuntimeError(f"key shape: {key_shape}, value shape: {value_shape}") from err

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        # remove attn mask if doing clip
        if self.adapter_ref().adapter_ref().config.text_encoder_arch == "clip":
            attention_mask = None

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
clip_with_projection.config.projection_dim, bias=False) + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + attn_dict_map = { + + } + module_idx = 0 + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + attn_processor_keys = [] + if is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + + else: + attn_processor_keys = list(sd.unet.attn_processors.keys()) + + attn_processor_names = [] + + blocks = [] + transformer_blocks = [] + for name in attn_processor_keys: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \ + sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer"): + hidden_size = sd.unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor2_0() + else: + layer_name = name.split(".processor")[0] + to_k_adapter = unet_sd[layer_name + ".to_k.weight"] + to_v_adapter = unet_sd[layer_name + ".to_v.weight"] + + # add zero padding to the adapter + if to_k_adapter.shape[1] < self.token_size: + to_k_adapter = torch.cat([ + to_k_adapter, + torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to( + to_k_adapter.device, 
dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + to_v_adapter = torch.cat([ + to_v_adapter, + torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + elif to_k_adapter.shape[1] > self.token_size: + to_k_adapter = to_k_adapter[:, :self.token_size] + to_v_adapter = to_v_adapter[:, :self.token_size] + else: + to_k_adapter = to_k_adapter + to_v_adapter = to_v_adapter + + # todo resize to the TE hidden size + weights = { + "to_k_adapter.weight": to_k_adapter, + "to_v_adapter.weight": to_v_adapter, + } + + if self.sd_ref().is_pixart: + # pixart is much more sensitive + weights = { + "to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01, + "to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01, + } + + attn_procs[name] = TEAdapterAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.adapter_ref().config.num_tokens, + adapter=self, + adapter_hidden_size=self.token_size, + layer_name=layer_name + ) + attn_procs[name].load_state_dict(weights) + self.adapter_modules.append(attn_procs[name]) + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn2.processor for i in + range(len(transformer.transformer_blocks)) + ]) + self.caption_projection = TEAdapterCaptionProjection( + caption_channels=self.token_size, + adapter=self, + ) + + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + # make a getter to see if is active + @property + def is_active(self): + return 
self.adapter_ref().is_active + + def encode_text(self, text): + te: T5EncoderModel = self.te_ref() + tokenizer: T5Tokenizer = self.tokenizer_ref() + attn_mask_float = None + + # input_ids = tokenizer( + # text, + # max_length=77, + # padding="max_length", + # truncation=True, + # return_tensors="pt", + # ).input_ids.to(te.device) + # outputs = te(input_ids=input_ids) + # outputs = outputs.last_hidden_state + if self.adapter_ref().config.text_encoder_arch == "clip": + embeds = train_tools.encode_prompts( + tokenizer, + te, + text, + truncate=True, + max_length=self.adapter_ref().config.num_tokens, + ) + attention_mask = torch.ones(embeds.shape[:2], device=embeds.device) + + elif self.adapter_ref().config.text_encoder_arch == "pile-t5": + # just use aura pile + embeds, attention_mask = train_tools.encode_prompts_auraflow( + tokenizer, + te, + text, + truncate=True, + max_length=self.adapter_ref().config.num_tokens, + ) + + else: + embeds, attention_mask = train_tools.encode_prompts_pixart( + tokenizer, + te, + text, + truncate=True, + max_length=self.adapter_ref().config.num_tokens, + ) + if attention_mask is not None: + attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype) + if self.text_projection is not None: + # pool the output of embeds ignoring 0 in the attention mask + if attn_mask_float is not None: + pooled_output = embeds * attn_mask_float.unsqueeze(-1) + else: + pooled_output = embeds + + # reduce along dim 1 while maintaining batch and dim 2 + pooled_output_sum = pooled_output.sum(dim=1) + + if attn_mask_float is not None: + attn_mask_sum = attn_mask_float.sum(dim=1).unsqueeze(-1) + + pooled_output = pooled_output_sum / attn_mask_sum + + pooled_embeds = self.text_projection(pooled_output) + + prompt_embeds = PromptEmbeds( + (embeds, pooled_embeds), + attention_mask=attention_mask, + ).detach() + + else: + + prompt_embeds = PromptEmbeds( + embeds, + attention_mask=attention_mask, + ).detach() + + return prompt_embeds + + + + def 
class TEAugAdapterCLIPAttention(nn.Module):
    """CLIP self-attention wrapper that adds an adapter-conditioned attention branch.

    On construction the wrapper replaces ``attn_module.forward`` with its own
    ``forward``. The regular attention is always computed with the wrapped
    module's own projections. When the adapter is active and conditional embeds
    are available, a second attention pass uses keys/values projected from those
    embeds, is merged with the regular output through ``self.zipper`` and added
    to it before the wrapped module's ``out_proj``.

    Args:
        attn_module: The CLIP attention module to wrap (held weakly; its
            ``forward`` is patched in place).
        adapter: The owning ``TEAugAdapter`` (held weakly).
    """

    def __init__(self, attn_module: 'CLIPAttention', adapter: 'TEAugAdapter'):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.attn_module_ref: weakref.ref = weakref.ref(attn_module)
        self.k_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
        self.v_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
        # start from a scaled-down copy of the original projection weights
        self.k_proj_adapter.weight.data = attn_module.k_proj.weight.data.clone() * 0.01
        self.v_proj_adapter.weight.data = attn_module.v_proj.weight.data.clone() * 0.01
        # biases are scaled down even further
        self.k_proj_adapter.bias.data = attn_module.k_proj.bias.data.clone() * 0.001
        self.v_proj_adapter.bias.data = attn_module.v_proj.bias.data.clone() * 0.001

        # Merges the adapter branch and the regular branch (concatenated along
        # the token axis, hence 2 * 77 input tokens) back down to 77 tokens.
        self.zipper = ZipperModule(
            in_size=attn_module.embed_dim,
            in_tokens=77 * 2,
            out_size=attn_module.embed_dim,
            out_tokens=77,
            hidden_size=attn_module.embed_dim,
            hidden_tokens=77,
        )

        # replace the original forward with our forward
        self.original_forward = attn_module.forward
        attn_module.forward = self.forward

    @property
    def is_active(self):
        """Whether the owning adapter is currently active."""
        return self.adapter_ref().is_active

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, embed)`` into ``(bsz, heads, seq_len, head_dim)``.

        The head layout is read from the wrapped attention module, because this
        wrapper does not define ``num_heads`` / ``head_dim`` itself.
        """
        attn_module = self.attn_module_ref()
        return (
            tensor.view(bsz, seq_len, attn_module.num_heads, attn_module.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns a 2-tuple ``(attn_output, attn_weights_reshaped)``; the second
        item is ``None`` unless ``output_attentions`` is truthy.

        Raises:
            ValueError: If an intermediate tensor or a supplied mask does not
                have the expected shape.
        """

        attn_module = self.attn_module_ref()

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = attn_module.q_proj(hidden_states) * attn_module.scale
        key_states = attn_module._shape(attn_module.k_proj(hidden_states), -1, bsz)
        value_states = attn_module._shape(attn_module.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * attn_module.num_heads, -1, attn_module.head_dim)
        query_states = attn_module._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * attn_module.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * attn_module.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * attn_module.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        adapter: 'CustomAdapter' = self.adapter_ref().adapter_ref()
        if self.adapter_ref().is_active and adapter.conditional_embeds is not None:
            # apply the adapter

            if adapter.is_unconditional_run:
                embeds = adapter.unconditional_embeds
            else:
                embeds = adapter.conditional_embeds
                # if the shape is not the same on batch, we are doing cfg and need to concat unconditional as well
                if embeds.size(0) != bsz:
                    embeds = torch.cat([adapter.unconditional_embeds, embeds], dim=0)

            # second attention pass: same queries, keys/values from the embeds
            # (no attention mask is applied on this branch)
            key_states_raw = self.k_proj_adapter(embeds)
            key_states = attn_module._shape(key_states_raw, -1, bsz)
            value_states_raw = self.v_proj_adapter(embeds)
            value_states = attn_module._shape(value_states_raw, -1, bsz)
            key_states = key_states.view(*proj_shape)
            value_states = value_states.view(*proj_shape)
            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
            attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
            attn_output_adapter = torch.bmm(attn_probs, value_states)

            if attn_output_adapter.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
                raise ValueError(
                    f"`attn_output_adapter` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
                    f" {attn_output_adapter.size()}"
                )

            attn_output_adapter = attn_output_adapter.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
            attn_output_adapter = attn_output_adapter.transpose(1, 2)
            attn_output_adapter = attn_output_adapter.reshape(bsz, tgt_len, embed_dim)

            # merge both branches along the token axis, then add as a residual
            attn_output_adapter = self.zipper(torch.cat([attn_output_adapter, attn_output], dim=1))

            # attn_output_adapter = attn_module.out_proj(attn_output_adapter)
            attn_output = attn_output + attn_output_adapter

        attn_output = attn_module.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + + if adapter.config.image_encoder_arch.startswith('convnext'): + in_tokens = 16 * 16 + embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1] + + out_tokens = adapter.config.num_tokens if adapter.config.num_tokens > 0 else in_tokens + self.image_proj_model = ZipperModule( + in_size=embedding_dim, + in_tokens=in_tokens, + out_size=dim, + out_tokens=out_tokens, + hidden_size=dim, + hidden_tokens=out_tokens, + ) + # init adapter modules + attn_procs = {} + for idx, layer in enumerate(clip_encoder.layers): + name = f"clip_attention.{idx}" + attn_procs[name] = TEAugAdapterCLIPAttention( + layer.self_attn, + self + ) + + self.adapter_modules = torch.nn.ModuleList(list(attn_procs.values())) + + # make a getter to see if is active + @property + def is_active(self): + return self.adapter_ref().is_active + + + def forward(self, input): + # # apply the adapter + input = self.image_proj_model(input) + # self.embeds = input + return input diff --git a/ai-toolkit/toolkit/models/tipsv2.py b/ai-toolkit/toolkit/models/tipsv2.py new file mode 100644 index 0000000000000000000000000000000000000000..3b1fd65f4104c61b8b38f42eaa6d3f36eb476cbb --- /dev/null +++ b/ai-toolkit/toolkit/models/tipsv2.py @@ -0,0 +1,1044 @@ +"""Local implementation of google/tipsv2-b14-dpt. + +Self-contained port of the remote `trust_remote_code=True` model into ai-toolkit. +Includes vision encoder + DPT depth/normals/segmentation heads, with optional +gradient checkpointing on the vision transformer blocks. The text encoder is +intentionally not included — only the dense-prediction stack is used here. 
class Mlp(nn.Module):
    """Two-layer feed-forward block: ``fc1 -> act -> drop -> fc2 -> drop``.

    Args:
        in_features: Size of the input features.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the output features; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability used after both linear layers (one shared
            ``nn.Dropout`` instance).
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        output_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        # Sub-modules are registered in the order fc1, act, fc2, drop.
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block to ``x`` along its last dimension."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention backed by ``F.scaled_dot_product_attention``.

    A single fused linear layer produces queries, keys and values; the
    attended result is passed through an output projection and dropout.
    ``self.scale`` is stored but not read by ``forward`` -- the fused kernel
    is called without an explicit scale.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        # Sub-modules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the token axis of a ``(batch, tokens, channels)`` input."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # -> three tensors shaped (batch, heads, tokens, per_head)
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)
        # Attention dropout is only active in training mode.
        dropout_p = self.attn_drop.p if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            queries, keys, values, dropout_p=dropout_p
        )
        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
class VisionTransformer(nn.Module):
    """DINOv2-style ViT used as the TIPSv2 vision backbone.

    After :meth:`prepare_tokens_with_masks` the token axis is laid out as
    ``[cls, register tokens..., patch tokens...]``.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        ffn_bias: bool = True,
        proj_bias: bool = True,
        drop_path_rate: float = 0.0,
        init_values: Optional[float] = 1.0,
        ffn_layer: str = "mlp",
        num_register_tokens: int = 1,
        interpolate_antialias: bool = True,
        interpolate_offset: float = 0.0,
    ):
        """Build the patch embedding, learned tokens and ``depth`` blocks.

        Raises:
            NotImplementedError: If ``ffn_layer`` is anything other than ``"mlp"``.
        """
        super().__init__()
        # Fail before allocating any parameters; only the MLP FFN is ported.
        if ffn_layer != "mlp":
            raise NotImplementedError(
                f"ffn_layer={ffn_layer!r} not supported in local port"
            )
        norm_layer = functools.partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.gradient_checkpointing = False

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
        )
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        # Stochastic-depth rate grows linearly from 0 to drop_path_rate.
        dpr = [drop_path_rate * i / max(depth - 1, 1) for i in range(depth)]
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    ffn_layer=Mlp,
                    init_values=init_values,
                )
                for i in range(depth)
            ]
        )
        # Maintain weight-key compat with the upstream non-chunked branch.
        self.chunked_blocks = False

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

    # ---- gradient checkpointing toggles ------------------------------------

    def gradient_checkpointing_enable(self, **_kwargs) -> None:
        """Turn on per-block activation checkpointing (extra kwargs are ignored)."""
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self) -> None:
        """Turn off per-block activation checkpointing."""
        self.gradient_checkpointing = False

    enable_gradient_checkpointing = gradient_checkpointing_enable
    disable_gradient_checkpointing = gradient_checkpointing_disable

    # ---- positional embedding / token prep ---------------------------------

    def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
        """Return positional embeddings sized for the current token grid.

        The stored embedding is returned unchanged when the patch-token count
        matches and ``w == h``. Otherwise the patch part is resampled
        bilinearly in float32 and the result is cast back to ``x.dtype``.

        Args:
            x: Tokens with the class token already prepended.
            w: Size of input dimension 2, as passed by the caller.
            h: Size of input dimension 3, as passed by the caller.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        num_patches = self.pos_embed.shape[1] - 1
        if npatch == num_patches and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        side = int(math.sqrt(num_patches))
        # The stored grid must be square for the reshape below.
        assert num_patches == side * side
        kwargs = {}
        if self.interpolate_offset:
            kwargs["scale_factor"] = (
                float(w0 + self.interpolate_offset) / side,
                float(h0 + self.interpolate_offset) / side,
            )
        else:
            kwargs["size"] = (w0, h0)
        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2),
            mode="bilinear",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    def prepare_tokens_with_masks(
        self, x: torch.Tensor, masks: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embed patches, apply optional masking, and add cls/pos/register tokens."""
        _, _, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            # Masked positions are replaced by the learned mask token.
            x = torch.where(
                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
            )
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)
        if self.register_tokens is not None:
            # Registers are inserted between the class token and the patches.
            x = torch.cat(
                (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
                dim=1,
            )
        return x

    # ---- block runner with optional checkpointing --------------------------

    def _run_blocks(
        self, x: torch.Tensor, collect_indices: Optional[Sequence[int]] = None
    ):
        """Run every block; optionally collect outputs of the listed blocks.

        Returns ``x`` alone when ``collect_indices`` is ``None``, otherwise
        ``(x, collected)`` with ``collected`` in ascending block order.
        """
        collected = [] if collect_indices is not None else None
        use_ckpt = self.gradient_checkpointing and self.training
        for i, blk in enumerate(self.blocks):
            if use_ckpt:
                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            if collected is not None and i in collect_indices:
                collected.append(x)
        return (x, collected) if collected is not None else x

    # ---- public forwards ---------------------------------------------------

    def forward_features(
        self, x: torch.Tensor, masks: Optional[torch.Tensor] = None
    ) -> dict:
        """Run the encoder and return normalized tokens split by role."""
        x = self.prepare_tokens_with_masks(x, masks)
        x = self._run_blocks(x)
        x_norm = self.norm(x)
        return {
            "x_norm_1st_clstoken": x_norm[:, :1],
            "x_norm_2nd_clstoken": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence[int]] = 1,
        reshape: bool = False,
        return_class_token: bool = False,
        norm: bool = True,
    ):
        """Return patch tokens (and optionally class tokens) of selected blocks.

        Args:
            x: Input image batch.
            n: An ``int`` selects the last ``n`` blocks. A sequence selects
                explicit block indices; negative values count from the end and
                repeated values are allowed. Results follow the order given.
            reshape: If true, patch tokens are reshaped to a channel-first
                spatial map using the input size divided by ``patch_size``.
            return_class_token: If true, each entry is ``(patches, cls)``.
            norm: If true, the final norm is applied to every selected output.

        Returns:
            A tuple with one entry per requested block.

        Raises:
            IndexError: If ``n`` asks for more blocks than exist or contains an
                index outside ``[-depth, depth)``.
        """
        total = len(self.blocks)
        if isinstance(n, int):
            if n > total:
                raise IndexError(
                    f"requested the last {n} blocks but the encoder has only {total}"
                )
            indices = list(range(total - n, total))
        else:
            indices = []
            for idx in n:
                resolved = idx + total if idx < 0 else idx
                if not 0 <= resolved < total:
                    raise IndexError(
                        f"block index {idx} is out of range for {total} blocks"
                    )
                indices.append(resolved)

        x_in = x
        x = self.prepare_tokens_with_masks(x)
        # Blocks run once; outputs are keyed by block index so that duplicates
        # and arbitrary request order map back to the right tensor.
        wanted = sorted(set(indices))
        _, collected = self._run_blocks(x, collect_indices=wanted)
        by_index = dict(zip(wanted, collected))
        outputs = [by_index[idx] for idx in indices]
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            b, _, w, h = x_in.shape
            outputs = [
                out.reshape(b, w // self.patch_size, h // self.patch_size, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward_hidden_states(self, x: torch.Tensor):
        """Per-layer hidden states for use as a perceptual feature stack.

        Returns a tuple ``(embeddings, block_1_out, ..., block_L_out)`` of length
        ``depth + 1`` (HuggingFace ``output_hidden_states`` convention), each
        ``(B, 1 + num_register_tokens + num_patches, embed_dim)``. No final norm is
        applied; intermediate layers are returned raw."""
        x = self.prepare_tokens_with_masks(x)
        hidden_states = [x]
        # Gate checkpointing on grad tracking (not self.training) so the no_grad
        # target pass doesn't pay for wasted recompute.
        use_ckpt = self.gradient_checkpointing and torch.is_grad_enabled()
        for blk in self.blocks:
            if use_ckpt:
                x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            hidden_states.append(x)
        return tuple(hidden_states)

    def forward(self, x: torch.Tensor, is_training: bool = False):
        """Return the feature dict when ``is_training``, else a 3-tuple.

        The tuple is ``(class token, register tokens, patch tokens)``; the first
        two are passed through ``self.head``.
        """
        ret = self.forward_features(x)
        if is_training:
            return ret
        return (
            self.head(ret["x_norm_1st_clstoken"]),
            self.head(ret["x_norm_2nd_clstoken"]),
            ret["x_norm_patchtokens"],
        )
class ReassembleBlocks(nn.Module):
    """Turn four ``(cls_token, patch_map)`` pairs into a multi-scale pyramid.

    Each stage optionally folds the class token into the patch features
    (``readout_type == "project"``), projects to that stage's channel width
    with a 1x1 convolution, and then resizes: 4x up, 2x up, unchanged, 2x down.
    """

    def __init__(
        self,
        input_embed_dim: int = 1024,
        out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
        readout_type: str = "project",
    ):
        super().__init__()
        self.readout_type = readout_type

        # Construction order is kept so parameter init draws from the RNG
        # in the same sequence as before.
        projections = [nn.Conv2d(input_embed_dim, width, 1) for width in out_channels]
        self.out_projections = nn.ModuleList(projections)

        upsample_x4 = nn.ConvTranspose2d(
            out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0
        )
        upsample_x2 = nn.ConvTranspose2d(
            out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0
        )
        passthrough = nn.Identity()
        downsample_x2 = nn.Conv2d(out_channels[3], out_channels[3], 3, stride=2, padding=1)
        self.resize_layers = nn.ModuleList(
            [upsample_x4, upsample_x2, passthrough, downsample_x2]
        )

        if readout_type == "project":
            readouts = [
                nn.Linear(2 * input_embed_dim, input_embed_dim) for _ in out_channels
            ]
            self.readout_projects = nn.ModuleList(readouts)

    def _fold_in_readout(self, stage: int, cls_token, patch_map, shape):
        """Concatenate the class token onto every patch and project back."""
        batch, depth, height, width = shape
        tokens = patch_map.flatten(2).transpose(1, 2)
        broadcast_cls = cls_token.unsqueeze(1).expand(-1, tokens.shape[1], -1)
        joined = torch.cat([tokens, broadcast_cls], dim=-1)
        projected = F.gelu(self.readout_projects[stage](joined))
        return projected.transpose(1, 2).reshape(batch, depth, height, width)

    def forward(self, features):
        """Return one resized feature map per input ``(cls_token, patch_map)`` pair."""
        pyramid = []
        for stage, (cls_token, patch_map) in enumerate(features):
            B, D, H, W = patch_map.shape
            if self.readout_type == "project":
                patch_map = self._fold_in_readout(
                    stage, cls_token, patch_map, (B, D, H, W)
                )
            projected = self.out_projections[stage](patch_map)
            pyramid.append(self.resize_layers[stage](projected))
        return pyramid
has_residual=True), + ] + ) + + +class _DPTHeadBase(nn.Module): + """Shared reassemble + fuse + project trunk used by all three task heads.""" + + def __init__( + self, + input_embed_dim: int, + channels: int, + post_process_channels: Tuple[int, ...], + readout_type: str, + ): + super().__init__() + self.reassemble = ReassembleBlocks( + input_embed_dim=input_embed_dim, + out_channels=post_process_channels, + readout_type=readout_type, + ) + self.convs = nn.ModuleList( + [ + nn.Conv2d(ch, channels, 3, padding=1, bias=False) + for ch in post_process_channels + ] + ) + self.fusion_blocks = _build_fusion_stack(channels) + self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True) + self.gradient_checkpointing = False + + def gradient_checkpointing_enable(self, **_kwargs) -> None: + self.gradient_checkpointing = True + + def gradient_checkpointing_disable(self) -> None: + self.gradient_checkpointing = False + + def _ckpt(self, module, *args): + # Checkpoint only when grads are actually being tracked, so the no_grad + # target pass doesn't pay for wasted recompute. + if self.gradient_checkpointing and torch.is_grad_enabled(): + return torch.utils.checkpoint.checkpoint(module, *args, use_reentrant=False) + return module(*args) + + def _trunk(self, intermediate_features) -> torch.Tensor: + x = self.reassemble(intermediate_features) + x = [self.convs[i](feat) for i, feat in enumerate(x)] + out = self._ckpt(self.fusion_blocks[0], x[-1]) + for i in range(1, 4): + out = self._ckpt(self.fusion_blocks[i], out, x[-(i + 1)]) + return self.project(out) + + +class DPTDepthHead(_DPTHeadBase): + def __init__( + self, + input_embed_dim: int = 1024, + channels: int = 256, + post_process_channels: Tuple[int, ...] 
class DPTNormalsHead(_DPTHeadBase):
    """DPT head that predicts a 3-channel, L2-normalized map."""

    def __init__(
        self,
        input_embed_dim: int = 1024,
        channels: int = 256,
        post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024),
        readout_type: str = "project",
    ):
        super().__init__(input_embed_dim, channels, post_process_channels, readout_type)
        self.normals_head = nn.Linear(channels, 3)

    def forward(self, intermediate_features, image_size=None) -> torch.Tensor:
        """Predict normals, optionally resized to ``image_size``.

        The linear layer acts on the channel axis, so the trunk output is moved
        to channels-last for it and back to channels-first afterwards. Vectors
        are normalized before the optional bilinear resize.
        """
        channels_last = self._trunk(intermediate_features).permute(0, 2, 3, 1)
        unit_vectors = F.normalize(self.normals_head(channels_last), p=2, dim=-1)
        normals = unit_vectors.permute(0, 3, 1, 2)
        if image_size is None:
            return normals
        return F.interpolate(
            normals, size=image_size, mode="bilinear", align_corners=False
        )
class TIPSv2DPTModel(nn.Module):
    """TIPSv2 DPT dense-prediction model (depth, normals, segmentation).

    A single vision encoder feeds three DPT heads. :meth:`forward` runs the
    encoder once and evaluates all heads; the ``predict_*`` helpers each run
    the encoder themselves and evaluate one head.

    Use :meth:`from_pretrained` to load weights for `google/tipsv2-b14-dpt`.
    """

    def __init__(self, config: Optional[dict] = None):
        """Build the encoder and the three heads.

        Args:
            config: Optional overrides applied on top of ``_B14_DPT_CONFIG``.

        Raises:
            NotImplementedError: If ``vision_fn`` is not ``"vit_base"``.
        """
        super().__init__()
        # Start from the built-in defaults and layer caller overrides on top.
        cfg = dict(_B14_DPT_CONFIG)
        if config:
            cfg.update(config)
        self.config = cfg

        builders = {"vit_base": _vit_base}
        if cfg["vision_fn"] not in builders:
            raise NotImplementedError(f"vision_fn={cfg['vision_fn']!r} not supported")

        # NOTE(review): cfg["embed_dim"] and cfg["num_register_tokens"] are not
        # forwarded to the encoder builder; embed_dim only reaches the heads.
        self.vision_encoder = builders[cfg["vision_fn"]](
            img_size=cfg["img_size"],
            patch_size=cfg["patch_size"],
            ffn_layer=cfg["ffn_layer"],
            init_values=cfg["init_values"],
            interpolate_antialias=True,
            interpolate_offset=0.0,
        )

        ppc = tuple(cfg["post_process_channels"])
        self.depth_head = DPTDepthHead(
            input_embed_dim=cfg["embed_dim"],
            channels=cfg["channels"],
            post_process_channels=ppc,
            readout_type=cfg["readout_type"],
            num_depth_bins=cfg["num_depth_bins"],
            min_depth=cfg["min_depth"],
            max_depth=cfg["max_depth"],
        )
        self.normals_head = DPTNormalsHead(
            input_embed_dim=cfg["embed_dim"],
            channels=cfg["channels"],
            post_process_channels=ppc,
            readout_type=cfg["readout_type"],
        )
        self.segmentation_head = DPTSegmentationHead(
            input_embed_dim=cfg["embed_dim"],
            channels=cfg["channels"],
            post_process_channels=ppc,
            readout_type=cfg["readout_type"],
            num_classes=cfg["num_seg_classes"],
        )

    # ---- properties + checkpointing ---------------------------------------

    @property
    def device(self) -> torch.device:
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    def gradient_checkpointing_enable(self, **kwargs) -> None:
        """Enable gradient checkpointing on the vision transformer blocks and DPT heads."""
        self.vision_encoder.gradient_checkpointing_enable(**kwargs)
        for head in (self.depth_head, self.normals_head, self.segmentation_head):
            head.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self) -> None:
        """Disable gradient checkpointing on the encoder and all three heads."""
        self.vision_encoder.gradient_checkpointing_disable()
        for head in (self.depth_head, self.normals_head, self.segmentation_head):
            head.gradient_checkpointing_disable()

    enable_gradient_checkpointing = gradient_checkpointing_enable
    disable_gradient_checkpointing = gradient_checkpointing_disable

    # ---- core inference path ----------------------------------------------

    def _extract_intermediate(self, pixel_values: torch.Tensor):
        """Return ``(cls_token, patch_feats)`` pairs for the configured blocks."""
        intermediate = self.vision_encoder.get_intermediate_layers(
            pixel_values,
            n=tuple(self.config["block_indices"]),
            reshape=True,
            return_class_token=True,
            norm=True,
        )
        # Returned as (cls_token, patch_feats) tuples to match the remote API.
        return [(cls_tok, patch_feat) for patch_feat, cls_tok in intermediate]

    def predict_depth(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the encoder and the depth head, resized to the input's spatial size."""
        h, w = pixel_values.shape[2:]
        return self.depth_head(
            self._extract_intermediate(pixel_values), image_size=(h, w)
        )

    def predict_normals(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the encoder and the normals head, resized to the input's spatial size."""
        h, w = pixel_values.shape[2:]
        return self.normals_head(
            self._extract_intermediate(pixel_values), image_size=(h, w)
        )

    def predict_segmentation(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the encoder and the segmentation head, resized to the input's spatial size."""
        h, w = pixel_values.shape[2:]
        return self.segmentation_head(
            self._extract_intermediate(pixel_values), image_size=(h, w)
        )

    def forward(self, pixel_values: torch.Tensor) -> TIPSv2DPTOutput:
        """Run the encoder once and evaluate all three heads on its features."""
        h, w = pixel_values.shape[2:]
        feats = self._extract_intermediate(pixel_values)
        return TIPSv2DPTOutput(
            depth=self.depth_head(feats, image_size=(h, w)),
            normals=self.normals_head(feats, image_size=(h, w)),
            segmentation=self.segmentation_head(feats, image_size=(h, w)),
        )

    # ---- loader -----------------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = "google/tipsv2-b14-dpt",
        device: Union[str, torch.device] = "cpu",
        dtype: torch.dtype = torch.float32,
        cache_dir: Optional[str] = None,
    ) -> "TIPSv2DPTModel":
        """Build the model and load weights from the hub.

        Pulls the DPT head weights from ``model_id`` (default
        ``google/tipsv2-b14-dpt``) and the vision-encoder weights from the
        backbone repo specified in the DPT config.

        Args:
            model_id: Hub repo id; only ``google/tipsv2-b14-dpt`` is accepted.
            device: Device the loaded model is moved to.
            dtype: Dtype the loaded model is cast to.
            cache_dir: Optional cache directory passed to ``hf_hub_download``.

        Raises:
            NotImplementedError: If ``model_id`` is any other repo.
        """
        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file

        if model_id != "google/tipsv2-b14-dpt":
            raise NotImplementedError(
                f"Local TIPSv2DPTModel only supports 'google/tipsv2-b14-dpt'; got {model_id!r}"
            )

        model = cls()

        dpt_ckpt = hf_hub_download(model_id, "model.safetensors", cache_dir=cache_dir)
        dpt_state = load_file(dpt_ckpt)

        backbone_ckpt = hf_hub_download(
            model.config["backbone_repo"],
            "model.safetensors",
            cache_dir=cache_dir,
        )
        backbone_state = load_file(backbone_ckpt)
        # Backbone repo stores both vision and text encoders — keep only vision_encoder.*.
        backbone_state = {
            k: v for k, v in backbone_state.items() if k.startswith("vision_encoder.")
        }

        # On a key collision the backbone tensor wins (it is unpacked last).
        merged = {**dpt_state, **backbone_state}
        # Non-strict load: mismatches are reported below instead of raising.
        missing, unexpected = model.load_state_dict(merged, strict=False)
        if missing:
            print(
                f"[tipsv2] Missing keys ({len(missing)}): {missing[:8]}{'...' if len(missing) > 8 else ''}"
            )
        if unexpected:
            print(
                f"[tipsv2] Unexpected keys ({len(unexpected)}): {unexpected[:8]}{'...' if len(unexpected) > 8 else ''}"
            )

        model.to(device=device, dtype=dtype)
        return model
class TIPSv2VisionModel(nn.Module):
    """Vision-only TIPSv2 encoder exposing per-layer hidden states.

    Loads just the ``vision_encoder.*`` weights from a TIPSv2 repo (e.g.
    ``google/tipsv2-so400m14``); the text encoder and DPT heads are never built.
    Inputs are expected in ``[0, 1]`` — TIPSv2 applies no image normalization.
    Used as a perceptual feature extractor (see TipsV2FE).
    """

    def __init__(self, config: dict):
        """Build the encoder named by ``config["vision_fn"]``.

        Missing keys fall back to ``vit_base``, ``img_size=448``,
        ``patch_size=14``, ``ffn_layer="mlp"`` and ``init_values=1.0``.

        Raises:
            NotImplementedError: If ``vision_fn`` has no registered builder.
        """
        super().__init__()
        self.config = config
        vision_fn = config.get("vision_fn", "vit_base")
        if vision_fn not in _VISION_BUILDERS:
            raise NotImplementedError(f"vision_fn={vision_fn!r} not supported")
        self.vision_encoder = _VISION_BUILDERS[vision_fn](
            img_size=config.get("img_size", 448),
            patch_size=config.get("patch_size", 14),
            ffn_layer=config.get("ffn_layer", "mlp"),
            init_values=config.get("init_values", 1.0),
            interpolate_antialias=True,
            interpolate_offset=0.0,
        )
        # Small HF-like view of the encoder geometry, read back from the encoder.
        self.vision_config = _VisionEncoderConfig(
            num_hidden_layers=self.vision_encoder.n_blocks,
            num_register_tokens=self.vision_encoder.num_register_tokens,
            patch_size=self.vision_encoder.patch_size,
            hidden_size=self.vision_encoder.embed_dim,
        )

    @property
    def device(self) -> torch.device:
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    def gradient_checkpointing_enable(self, **kwargs) -> None:
        """Enable gradient checkpointing on the wrapped encoder."""
        self.vision_encoder.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self) -> None:
        """Disable gradient checkpointing on the wrapped encoder."""
        self.vision_encoder.gradient_checkpointing_disable()

    enable_gradient_checkpointing = gradient_checkpointing_enable
    disable_gradient_checkpointing = gradient_checkpointing_disable

    def forward(
        self, pixel_values: torch.Tensor, output_hidden_states: bool = True
    ) -> "TIPSv2VisionOutput":
        """Encode ``pixel_values``.

        Args:
            pixel_values: Input image batch.
            output_hidden_states: When true (the default) the per-layer hidden
                states are attached to the output; when false
                ``hidden_states`` is ``None``. Previously the flag was ignored
                and the states were always returned.

        Returns:
            Output whose ``last_hidden_state`` is the final block output after
            the encoder norm and whose ``pooler_output`` is its first token.
        """
        hidden_states = self.vision_encoder.forward_hidden_states(pixel_values)
        last = self.vision_encoder.norm(hidden_states[-1])
        return TIPSv2VisionOutput(
            last_hidden_state=last,
            pooler_output=last[:, 0],
            hidden_states=hidden_states if output_hidden_states else None,
        )

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = "google/tipsv2-so400m14",
        device: Union[str, torch.device] = "cpu",
        dtype: torch.dtype = torch.float32,
        cache_dir: Optional[str] = None,
    ) -> "TIPSv2VisionModel":
        """Build the vision encoder and load its weights from the hub.

        Reads ``config.json`` to pick the vision architecture, then loads only the
        ``vision_encoder.*`` tensors from ``model.safetensors``.

        Raises:
            RuntimeError: If the checkpoint holds no ``vision_encoder.*`` tensors.
        """
        import json

        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file

        config_path = hf_hub_download(model_id, "config.json", cache_dir=cache_dir)
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        model = cls(config)

        ckpt = hf_hub_download(model_id, "model.safetensors", cache_dir=cache_dir)
        state = load_file(ckpt)
        # Repo stores vision + text encoders — keep only vision_encoder.*.
        state = {k: v for k, v in state.items() if k.startswith("vision_encoder.")}
        if not state:
            raise RuntimeError(f"No vision_encoder weights found in {model_id}")

        # Non-strict load: mismatches are reported below instead of raising.
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            print(
                f"[tipsv2] Missing keys ({len(missing)}): {missing[:8]}{'...' if len(missing) > 8 else ''}"
            )
        if unexpected:
            print(
                f"[tipsv2] Unexpected keys ({len(unexpected)}): {unexpected[:8]}{'...' if len(unexpected) > 8 else ''}"
            )

        model.to(device=device, dtype=dtype)
        return model
class MLPR(nn.Module):  # MLP with reshaping
    """LayerNorm -> Linear -> GELU over the last axis, then a 1x1 ``Conv1d``
    that remaps the channel axis (dimension 1).

    Args:
        in_dim: Size of the last input axis.
        in_channels: Size of input dimension 1.
        out_dim: Size of the last output axis.
        out_channels: Size of output dimension 1.
        use_residual: Only validates that ``in_dim == out_dim``. No skip
            connection is added in ``forward``; the output is the plain
            norm/linear/GELU/conv result.

    Raises:
        AssertionError: If ``use_residual`` is true and ``in_dim != out_dim``.
    """

    def __init__(
            self,
            in_dim,
            in_channels,
            out_dim,
            out_channels,
            use_residual=True
    ):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim, (
                f"use_residual requires in_dim == out_dim, got {in_dim} and {out_dim}"
            )
        # dont normalize if using conv
        self.layer_norm = nn.LayerNorm(in_dim)

        self.fc1 = nn.Linear(in_dim, out_dim)
        self.act_fn = nn.GELU()
        self.conv1 = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        """Map ``(batch, in_channels, in_dim)`` to ``(batch, out_channels, out_dim)``."""
        # The former `residual = x` assignment was never read and is dropped;
        # the computed values are unchanged.
        x = self.layer_norm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.conv1(x)
        return x
class VisionDirectAdapterAttnProcessor(nn.Module):
    r"""
    Attention processor that adds an adapter-driven attention pass on top of the
    regular PyTorch 2.0 scaled-dot-product attention.

    The standard query/key/value attention is always computed. When the owning
    adapter reports ``is_active`` and exposes conditional embeds, those embeds are
    projected through ``to_k_adapter`` / ``to_v_adapter``, attended to with the
    same query, and the result (times ``scale``) is added to the hidden states
    before the output projection.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer (output width of the adapter projections).
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`. Stored only.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        adapter:
            Object exposing ``is_active`` and ``adapter_ref()``; it is held through
            a weak reference, so the caller must keep it alive.
        adapter_hidden_size (`int`):
            Width of the embeds fed to ``to_k_adapter`` / ``to_v_adapter``.
        has_bias (`bool`, defaults to False):
            Whether the adapter projections carry a bias.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
                 adapter_hidden_size=None, has_bias=False, **kwargs):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.adapter_ref: weakref.ref = weakref.ref(adapter)

        self.hidden_size = hidden_size
        self.adapter_hidden_size = adapter_hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale

        self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
        self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)

    @property
    def is_active(self):
        return self.adapter_ref().is_active
        # return False

    @property
    def unconditional_embeds(self):
        return self.adapter_ref().adapter_ref().unconditional_embeds

    @property
    def conditional_embeds(self):
        return self.adapter_ref().adapter_ref().conditional_embeds

    @staticmethod
    def _shape_of(tensor):
        # Diagnostics helper: intermediates that were never computed are reported as None.
        return None if tensor is None else tensor.shape

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
    ):
        """Run attention for ``attn`` and, if the adapter is active, add its contribution.

        If the adapter pass raises, the error is logged (shapes plus traceback)
        and the result of the regular attention alone is used.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # flatten (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # will be none if disabled
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # only use one TE or the other. If our adapter is active only use ours
        if self.is_active and self.conditional_embeds is not None:
            # Pre-bind every intermediate so the error report below can never hit an
            # unbound local and mask the original exception.
            adapter_hidden_states = vd_key = vd_value = vd_hidden_states = None
            try:
                adapter_hidden_states = self.conditional_embeds
                if adapter_hidden_states.shape[0] == batch_size // 2:
                    # batch holds unconditional + conditional halves; prepend the unconditional embeds
                    adapter_hidden_states = torch.cat([
                        self.unconditional_embeds,
                        adapter_hidden_states
                    ], dim=0)
                # if it is image embeds, we need to add a 1 dim at inx 1
                if len(adapter_hidden_states.shape) == 2:
                    adapter_hidden_states = adapter_hidden_states.unsqueeze(1)

                # for ip-adapter
                vd_key = self.to_k_adapter(adapter_hidden_states)
                vd_value = self.to_v_adapter(adapter_hidden_states)

                vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.scale when we move to Torch 2.1
                vd_hidden_states = F.scaled_dot_product_attention(
                    query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )

                vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                vd_hidden_states = vd_hidden_states.to(query.dtype)

                hidden_states = hidden_states + self.scale * vd_hidden_states
            except Exception:
                print("Error in VisionDirectAdapterAttnProcessor")
                # print shapes of all tensors
                print(f"hidden_states: {self._shape_of(hidden_states)}")
                print(f"adapter_hidden_states: {self._shape_of(adapter_hidden_states)}")
                print(f"vd_key: {self._shape_of(vd_key)}")
                print(f"vd_value: {self._shape_of(vd_value)}")
                print(f"vd_hidden_states: {self._shape_of(vd_hidden_states)}")
                print(f"query: {self._shape_of(query)}")
                print(f"key: {self._shape_of(key)}")
                print(f"value: {self._shape_of(value)}")
                print(f"inner_dim: {inner_dim}")
                print(f"head_dim: {head_dim}")
                print(f"batch_size: {batch_size}")
                traceback.print_exc()

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
+ query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = 
F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # begin ip adapter + if self.is_active and self.conditional_embeds is not None: + adapter_hidden_states = self.conditional_embeds + block_scaler = self.adapter_ref().block_scaler + if block_scaler is not None: + # add 1 to block scaler so we can decay its weight to 1.0 + block_scaler = block_scaler[self.block_idx] + 1.0 + + if adapter_hidden_states.shape[0] < batch_size: + adapter_hidden_states = torch.cat([ + self.unconditional_embeds, + adapter_hidden_states + ], dim=0) + # if it is image embeds, we need to add a 1 dim at inx 1 + if len(adapter_hidden_states.shape) == 2: + adapter_hidden_states = adapter_hidden_states.unsqueeze(1) + # conditional_batch_size = adapter_hidden_states.shape[0] + # conditional_query = query + + # for ip-adapter + vd_key = self.to_k_adapter(adapter_hidden_states) + vd_value = self.to_v_adapter(adapter_hidden_states) + + vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + vd_hidden_states = F.scaled_dot_product_attention( + query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + vd_hidden_states = vd_hidden_states.to(query.dtype) + + # scale to block scaler + if block_scaler is not None: + orig_dtype = vd_hidden_states.dtype + if block_scaler.dtype != vd_hidden_states.dtype: + vd_hidden_states = vd_hidden_states.to(block_scaler.dtype) + vd_hidden_states = vd_hidden_states * block_scaler + if block_scaler.dtype != orig_dtype: + vd_hidden_states = vd_hidden_states.to(orig_dtype) + + hidden_states = hidden_states + self.scale * vd_hidden_states + + if encoder_hidden_states is not None: 
+ encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + +class VisionDirectAdapter(torch.nn.Module): + def __init__( + self, + adapter: 'CustomAdapter', + sd: 'StableDiffusion', + vision_model: Union[CLIPVisionModelWithProjection], + ): + super(VisionDirectAdapter, self).__init__() + is_pixart = sd.is_pixart + is_flux = sd.is_flux + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref: weakref.ref = weakref.ref(sd) + self.config: AdapterConfig = adapter.config + self.vision_model_ref: weakref.ref = weakref.ref(vision_model) + self.resampler = None + is_pixtral = self.config.image_encoder_arch == "pixtral" + + if adapter.config.clip_layer == "image_embeds": + if isinstance(vision_model, SiglipVisionModel): + self.token_size = vision_model.config.hidden_size + else: + self.token_size = vision_model.config.projection_dim + else: + self.token_size = vision_model.config.hidden_size + + self.mid_size = self.token_size + + if self.config.conv_pooling and self.config.conv_pooling_stacks > 1: + self.mid_size = self.mid_size * self.config.conv_pooling_stacks + + # if pixtral, use cross attn dim for more sparse representation if only doing double transformers + if is_pixtral and self.config.flux_only_double: + if is_flux: + hidden_size = 3072 + else: + hidden_size = sd.unet.config['cross_attention_dim'] + self.mid_size = hidden_size + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + + attn_processor_keys = [] + if is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + + 
attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + + elif is_flux: + transformer: FluxTransformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn") + + if not self.config.flux_only_double: + # single transformer blocks do not have cross attn, but we will do them anyway + for i, module in transformer.single_transformer_blocks.named_children(): + attn_processor_keys.append(f"single_transformer_blocks.{i}.attn") + else: + attn_processor_keys = list(sd.unet.attn_processors.keys()) + + current_idx = 0 + + for name in attn_processor_keys: + if is_flux: + cross_attention_dim = None + else: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer") or name.startswith("single_transformer"): + if is_flux: + hidden_size = 3072 + else: + hidden_size = sd.unet.config['cross_attention_dim'] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None and not is_flux: + attn_procs[name] = AttnProcessor2_0() + else: + layer_name = name.split(".processor")[0] + if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux: + # is quantized + + to_k_adapter = torch.randn(hidden_size, hidden_size) * 0.01 + to_v_adapter = torch.randn(hidden_size, hidden_size) * 0.01 + to_k_adapter = 
to_k_adapter.to(self.sd_ref().torch_dtype) + to_v_adapter = to_v_adapter.to(self.sd_ref().torch_dtype) + else: + to_k_adapter = unet_sd[layer_name + ".to_k.weight"] + to_v_adapter = unet_sd[layer_name + ".to_v.weight"] + + # add zero padding to the adapter + if to_k_adapter.shape[1] < self.mid_size: + to_k_adapter = torch.cat([ + to_k_adapter, + torch.randn(to_k_adapter.shape[0], self.mid_size - to_k_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + to_v_adapter = torch.cat([ + to_v_adapter, + torch.randn(to_v_adapter.shape[0], self.mid_size - to_v_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + elif to_k_adapter.shape[1] > self.mid_size: + to_k_adapter = to_k_adapter[:, :self.mid_size] + to_v_adapter = to_v_adapter[:, :self.mid_size] + # if is_pixart: + # to_k_bias = to_k_bias[:self.mid_size] + # to_v_bias = to_v_bias[:self.mid_size] + else: + to_k_adapter = to_k_adapter + to_v_adapter = to_v_adapter + # if is_pixart: + # to_k_bias = to_k_bias + # to_v_bias = to_v_bias + + weights = { + "to_k_adapter.weight": to_k_adapter * 0.01, + "to_v_adapter.weight": to_v_adapter * 0.01, + } + # if is_pixart: + # weights["to_k_adapter.bias"] = to_k_bias + # weights["to_v_adapter.bias"] = to_v_bias\ + + if is_flux: + attn_procs[name] = CustomFluxVDAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + adapter=self, + adapter_hidden_size=self.mid_size, + has_bias=False, + block_idx=current_idx + ) + else: + attn_procs[name] = VisionDirectAdapterAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + adapter=self, + adapter_hidden_size=self.mid_size, + has_bias=False, + ) + current_idx += 1 + attn_procs[name].load_state_dict(weights) + + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in 
transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList([ + transformer.transformer_blocks[i].attn1.processor for i in range(len(transformer.transformer_blocks)) + ] + [ + transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks)) + ]) + elif self.sd_ref().is_flux: + # we have to set them ourselves + transformer: FluxTransformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"] + + if not self.config.flux_only_double: + # do single blocks too even though they dont have cross attn + for i, module in transformer.single_transformer_blocks.named_children(): + module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"] + + if not self.config.flux_only_double: + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn.processor for i in + range(len(transformer.transformer_blocks)) + ] + [ + transformer.single_transformer_blocks[i].attn.processor for i in + range(len(transformer.single_transformer_blocks)) + ] + ) + else: + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn.processor for i in + range(len(transformer.transformer_blocks)) + ] + ) + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + num_modules = len(self.adapter_modules) + if self.config.train_scaler: + self.block_scaler = torch.nn.Parameter(torch.tensor([0.0] * num_modules).to( + dtype=torch.float32, + device=self.sd_ref().device_torch + )) + self.block_scaler.data = self.block_scaler.data.to(torch.float32) + self.block_scaler.requires_grad = True + else: + self.block_scaler = None + + self.pool = None + + if 
def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Return the adapter's state dict (method of ``VisionDirectAdapter``).

    When ``config.train_scaler`` is set, only the block scaler is exported,
    under the key ``prefix + 'block_scaler'``; otherwise the regular
    ``nn.Module`` state dict is returned.

    Args:
        destination: optional mapping to write into; a new ``OrderedDict`` is
            created when omitted.
        prefix: string prepended to every key.
        keep_vars: when false, the exported scaler is detached from autograd,
            matching the behaviour of ``nn.Module.state_dict``; when true the
            live parameter is returned.

    Returns:
        The mapping that was written to.
    """
    if self.config.train_scaler:
        # only return the block scaler
        if destination is None:
            destination = OrderedDict()
        scaler = self.block_scaler
        destination[prefix + 'block_scaler'] = scaler if keep_vars else scaler.detach()
        return destination
    # Keyword arguments: positional arguments to state_dict are deprecated in newer PyTorch.
    return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
class AvgDown3D(nn.Module):
    """Parameter-free down-sampling by block folding followed by group averaging.

    The time axis is front-padded to a multiple of ``factor_t``. Each
    ``factor_t x factor_s x factor_s`` block is then folded into the channel
    axis, and the resulting ``in_channels * factor`` channels are averaged in
    consecutive groups of ``group_size`` to give ``out_channels`` channels.

    ``in_channels * factor`` must be divisible by ``out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.factor_t, self.factor_s = factor_t, factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        folded_channels = in_channels * self.factor
        assert folded_channels % out_channels == 0
        self.group_size = folded_channels // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ft, fs = self.factor_t, self.factor_s

        # Pad at the front of the time axis up to the next multiple of factor_t.
        front_pad = (-x.shape[2]) % ft
        x = F.pad(x, (0, 0, 0, 0, front_pad, 0))

        batch, channels, frames, height, width = x.shape
        t_out, h_out, w_out = frames // ft, height // fs, width // fs

        # Split each axis into (coarse, fine) and move the fine parts next to the channels.
        x = x.view(batch, channels, t_out, ft, h_out, fs, w_out, fs)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()

        # Fold the fine parts into channels, then average each group of channels.
        x = x.view(batch, channels * self.factor, t_out, h_out, w_out)
        grouped = x.view(batch, self.out_channels, self.group_size, t_out, h_out, w_out)
        return grouped.mean(dim=2)
+ + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class WanRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. 
+ """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class WanUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class WanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. 
+ """ + + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # default to dim //2 + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, 
class WanResidualBlock(nn.Module):
    r"""
    A residual block made of two (norm -> activation -> causal 3D conv) stages plus a shortcut.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        dropout (float, optional): Dropout rate applied before the second convolution. Default is 0.0.
        non_linearity (str, optional): Name of the activation, resolved through `get_activation`.
            Default is "silu".
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = get_activation(non_linearity)

        # layers
        self.norm1 = WanRMS_norm(in_dim, images=False)
        self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = WanRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
        # A 1x1x1 convolution matches channel counts on the shortcut; identity when they already agree.
        self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def _cached_conv(self, conv, x, feat_cache, feat_idx):
        """Run `conv` on `x` using cache slot `feat_idx[0]`, then refresh that slot and advance the index."""
        slot = feat_idx[0]
        previous = feat_cache[slot]

        # Keep the trailing CACHE_T frames of the conv input for the next chunk.
        tail = x[:, :, -CACHE_T:, :, :].clone()
        if tail.shape[2] < 2 and previous is not None:
            # Fewer than two frames available: prepend the last frame of the previous cache entry.
            tail = torch.cat([previous[:, :, -1, :, :].unsqueeze(2).to(tail.device), tail], dim=2)

        x = conv(x, previous)
        feat_cache[slot] = tail
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Apply the residual block.

        Args:
            x (torch.Tensor): Input tensor; indexed as 5-D (time on axis 2) when `feat_cache` is given.
            feat_cache (list, optional): Per-convolution feature cache, updated in place.
            feat_idx (list, optional): Single-element list holding the next cache slot; incremented in place
                once per cached convolution. A fresh `[0]` is created when omitted, so calls never share
                index state through a default argument.

        Returns:
            torch.Tensor: Output of the second convolution plus the shortcut branch.
        """
        if feat_idx is None:
            feat_idx = [0]

        # Shortcut branch sees the raw input.
        h = self.conv_shortcut(x)

        # First normalization, activation and convolution
        x = self.nonlinearity(self.norm1(x))
        if feat_cache is not None:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        # Second normalization, activation, dropout and convolution
        x = self.dropout(self.nonlinearity(self.norm2(x)))
        if feat_cache is not None:
            x = self._cached_conv(self.conv2, x, feat_cache, feat_idx)
        else:
            x = self.conv2(x)

        # Add residual connection
        return x + h
class WanResidualDownBlock(nn.Module):
    """
    Encoder stage made of residual blocks, an optional downsampler, and an `AvgDown3D` shortcut.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        dropout: Dropout rate forwarded to every `WanResidualBlock`.
        num_res_blocks: Number of residual blocks on the main path.
        temperal_downsample: Selects `factor_t=2` for the shortcut and, when `down_flag` is set,
            the "downsample3d" resample mode instead of "downsample2d".
        down_flag: Whether a `WanResample` downsampler is appended; also selects `factor_s=2` for the shortcut.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 dropout,
                 num_res_blocks,
                 temperal_downsample=False,
                 down_flag=False):
        super().__init__()

        # Shortcut path with downsample (AvgDown3D is defined elsewhere in this file)
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path with residual blocks and downsample
        resnets = []
        for _ in range(num_res_blocks):
            resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
            # Only the first block changes the channel count; the rest map out_dim -> out_dim.
            in_dim = out_dim
        self.resnets = nn.ModuleList(resnets)

        # Add the final downsample block
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            self.downsampler = WanResample(out_dim, mode=mode)
        else:
            self.downsampler = None

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """
        Run the main path and add the shortcut computed from the untouched input.

        Args:
            x: Input tensor.
            feat_cache: Optional feature cache list, forwarded to every sub-block.
            feat_idx: Single-element cache index list, forwarded to every sub-block.
                NOTE(review): the default is a shared mutable list; callers that pass `feat_cache`
                should always pass their own `feat_idx`.

        Returns:
            Sum of the main-path output and `avg_shortcut` applied to the original input.
        """
        # Keep an independent copy of the input for the shortcut branch.
        x_copy = x.clone()
        for resnet in self.resnets:
            x = resnet(x, feat_cache, feat_idx)
        if self.downsampler is not None:
            x = self.downsampler(x, feat_cache, feat_idx)

        return x + self.avg_shortcut(x_copy)
+ + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + in_channels: int = 3, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + is_residual: bool = False, # wan 2.2 vae use a residual downblock + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if is_residual: + self.down_blocks.append( + WanResidualDownBlock( + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False, + down_flag=i != len(dim_mult) - 1, + ) + ) + else: + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else 
"downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + use_ckpt = torch.is_grad_enabled() and self.gradient_checkpointing and feat_cache is None + + ## downsamples + for layer in self.down_blocks: + if use_ckpt: + x = self._gradient_checkpointing_func(layer, x) + elif feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + if use_ckpt: + x = self._gradient_checkpointing_func(self.mid_block, x) + else: + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + +class WanResidualUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. 
+ + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + temperal_upsample (bool): Whether to upsample on temporal dimension + up_flag (bool): Whether to upsample or not + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + temperal_upsample: bool = False, + up_flag: bool = False, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2, + ) + else: + self.avg_shortcut = None + + # create residual blocks + resnets = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + if up_flag: + upsample_mode = "upsample3d" if temperal_upsample else "upsample2d" + self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim) + else: + self.upsampler = None + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + """ + Forward pass through the upsampling block. 
class WanUpBlock(nn.Module):
    """
    A decoder stage for the WanVAE: a stack of residual blocks followed by an optional upsampler.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Base number of residual blocks; `num_res_blocks + 1` blocks are created
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode passed to `WanResample` ('upsample2d' or 'upsample3d');
            no upsampler is created when None
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: Optional[str] = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Create layers list
        resnets = []
        # Add residual blocks (only the first one maps in_dim -> out_dim; no attention is built here)
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed
        self.upsamplers = None
        if upsample_mode is not None:
            # Kept as a one-element ModuleList; forward() always uses index 0.
            self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management
            first_chunk (optional): Accepted but not used by this block.

        Returns:
            torch.Tensor: Output tensor
        """
        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsamplers is not None:
            if feat_cache is not None:
                x = self.upsamplers[0](x, feat_cache, feat_idx)
            else:
                x = self.upsamplers[0](x)
        return x
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + out_channels: int = 3, + is_residual: bool = False, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + + # init block + self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0 and not is_residual: + # wan vae 2.1 + in_dim = in_dim // 2 + + # determine if we need upsampling + up_flag = i != len(dim_mult) - 1 + # determine upsampling mode, if not upsampling, set to None + upsample_mode = None + if up_flag and temperal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" + # Create and add the upsampling block + if is_residual: + up_block = WanResidualUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + temperal_upsample=temperal_upsample[i] if up_flag else False, + up_flag= up_flag, + non_linearity=non_linearity, + ) + else: + up_block = WanUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def 
def unpatchify(x, patch_size):
    """
    Fold channel-packed patches back into the two trailing spatial axes (inverse of `patchify`).

    Args:
        x (torch.Tensor): 4-D `(b, c * patch_size**2, h, w)` or 5-D `(b, c * patch_size**2, f, h, w)` tensor.
        patch_size (int): Patch edge length. `1` is a no-op and returns `x` unchanged.

    Returns:
        torch.Tensor: Tensor with `patch_size**2` times fewer channels and both spatial axes
        enlarged by `patch_size`.

    Raises:
        ValueError: If `x` is neither 4-D nor 5-D (same contract as `patchify`).
    """
    if patch_size == 1:
        return x

    # Reject unsupported ranks explicitly instead of silently handing back a still-patchified tensor.
    rank = x.dim()
    if rank == 4:
        pattern = "b (c r q) h w -> b c (h q) (w r)"
    elif rank == 5:
        pattern = "b (c r q) f h w -> b c f (h q) (w r)"
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    # Imported lazily so the no-op and error paths do not require einops.
    from einops import rearrange

    return rearrange(x, pattern, q=patch_size, r=patch_size)
def enable_tiling(
    self,
    tile_sample_min_height: Optional[int] = None,
    tile_sample_min_width: Optional[int] = None,
    tile_sample_stride_height: Optional[float] = None,
    tile_sample_stride_width: Optional[float] = None,
) -> None:
    r"""
    Enable tiled VAE encoding and decoding. When this option is enabled, the VAE will split the input tensor
    into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of
    memory and to allow processing larger images.

    Any argument left as `None` (or given a falsy value such as `0`) keeps the currently configured value.

    Args:
        tile_sample_min_height (`int`, *optional*):
            The minimum height required for a sample to be separated into tiles across the height dimension.
        tile_sample_min_width (`int`, *optional*):
            The minimum width required for a sample to be separated into tiles across the width dimension.
        tile_sample_stride_height (`int`, *optional*):
            The stride between two consecutive vertical tiles. Consecutive tiles overlap by
            `tile_sample_min_height - tile_sample_stride_height`, which is blended to avoid tiling artifacts
            across the height dimension. Annotated as `float`, but `tiled_encode` uses it as a `range` step,
            so pass an integer.
        tile_sample_stride_width (`int`, *optional*):
            The stride between two consecutive horizontal tiles. Consecutive tiles overlap by
            `tile_sample_min_width - tile_sample_stride_width`, which is blended to avoid tiling artifacts
            across the width dimension. Annotated as `float`, but `tiled_encode` uses it as a `range` step,
            so pass an integer.
    """
    self.use_tiling = True
    # `or` falls back to the existing setting for None as well as for 0.
    self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
    self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
    self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
    self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def clear_cache(self):
    """
    Reset the decoder and encoder feature caches.

    Each cache becomes a list of `None` with one slot per `WanCausalConv3d`, using the counts precomputed in
    `self._cached_conv_counts` so the module tree is not walked on every call. Both running cache indices are
    reset to `[0]`.
    """
    counts = self._cached_conv_counts

    # Decoder-side state
    decoder_slots = counts["decoder"]
    self._conv_num = decoder_slots
    self._conv_idx = [0]
    self._feat_map = [None] * decoder_slots

    # Encoder-side state
    encoder_slots = counts["encoder"]
    self._enc_conv_num = encoder_slots
    self._enc_conv_idx = [0]
    self._enc_feat_map = [None] * encoder_slots
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + if self.config.clip_output: + out = torch.clamp(out, min=-1.0, max=1.0) + if self.config.patch_size is not None: + out = unpatchify(out, patch_size=self.config.patch_size) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """
    Linearly cross-fade the bottom rows of `a` into the top rows of `b` (axis -2), in place on `b`.

    The fade spans `blend_extent` rows, limited to the height of either tensor. Row `k` of `b` becomes
    `a[-extent + k] * (1 - k / extent) + b[k] * (k / extent)`.

    Returns:
        torch.Tensor: `b`, the same object that was passed in.
    """
    overlap = min(a.shape[-2], b.shape[-2], blend_extent)
    for row in range(overlap):
        weight_b = row / overlap
        weight_a = 1 - weight_b
        from_above = a[:, :, :, row - overlap, :]
        current = b[:, :, :, row, :]
        b[:, :, :, row, :] = from_above * weight_a + current * weight_b
    return b
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> "Union[DecoderOutput, Tuple[torch.Tensor]]":
    r"""
    Decode a batch of latents using a tiled decoder.

    The latent is split into overlapping spatial tiles. Each tile is decoded frame by frame with a fresh
    feature cache, neighbouring tiles are blended over their overlap, and the result is cropped to the
    expected sample size. Post-processing matches `_decode`: the output is clamped to `[-1, 1]` when
    `config.clip_output` is set and unpatchified when `config.patch_size` is set.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    _, _, num_frames, height, width = z.shape
    sample_height = height * self.spatial_compression_ratio
    sample_width = width * self.spatial_compression_ratio

    tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
    tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
    tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
    tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

    # Overlap between neighbouring decoded tiles, in sample pixels.
    blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
    blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

    # Split z into overlapping tiles and decode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, height, tile_latent_stride_height):
        row = []
        for j in range(0, width, tile_latent_stride_width):
            # Every tile is an independent causal sequence, so start from an empty cache.
            self.clear_cache()
            time = []
            for k in range(num_frames):
                self._conv_idx = [0]
                tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                tile = self.post_quant_conv(tile)
                # Flag the first frame of each tile exactly as `_decode` does for the first frame.
                decoded = self.decoder(
                    tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
                )
                time.append(decoded)
            row.append(torch.cat(time, dim=2))
        rows.append(row)
    self.clear_cache()

    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_width)
            result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
        result_rows.append(torch.cat(result_row, dim=-1))

    dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

    # Same post-processing as the non-tiled path so both paths return comparable samples.
    if self.config.clip_output:
        dec = torch.clamp(dec, min=-1.0, max=1.0)
    if self.config.patch_size is not None:
        dec = unpatchify(dec, patch_size=self.config.patch_size)

    if not return_dict:
        return (dec,)
    return DecoderOutput(sample=dec)
+ def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec \ No newline at end of file diff --git a/ai-toolkit/toolkit/models/wan21/wan21.py b/ai-toolkit/toolkit/models/wan21/wan21.py new file mode 100644 index 0000000000000000000000000000000000000000..8519b7d01b1bc2f59a7b5f42aef3e336ab9e9b7e --- /dev/null +++ b/ai-toolkit/toolkit/models/wan21/wan21.py @@ -0,0 +1,720 @@ +# WIP, coming soon ish +from functools import partial +import torch +import yaml +from toolkit.accelerator import unwrap_model +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.memory_management.manager import MemoryManager +from toolkit.models.base_model import BaseModel +from toolkit.prompt_utils import PromptEmbeds +from transformers import AutoTokenizer, UMT5EncoderModel +from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKL +from .autoencoder_kl_wan import AutoencoderKLWan +import os +import sys + +import weakref +import torch +import yaml + +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.models.base_model import BaseModel +from toolkit.prompt_utils import PromptEmbeds + +import os +import copy +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch +import torch +from 
optimum.quanto import freeze, qfloat8, QTensor, qint4 +from toolkit.util.quantize import quantize, get_qtype +from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler +from typing import TYPE_CHECKING, List +from toolkit.accelerator import unwrap_model +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from tqdm import tqdm +import torch.nn.functional as F +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE +# from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from typing import Any, Callable, Dict, List, Optional, Union +from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original +from toolkit.util.quantize import quantize_model +from toolkit.models.loaders.umt5 import get_umt5_encoder + +# for generation only? +scheduler_configUniPC = { + "_class_name": "UniPCMultistepScheduler", + "_diffusers_version": "0.33.0.dev0", + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "disable_corrector": [], + "dynamic_thresholding_ratio": 0.995, + "final_sigmas_type": "zero", + "flow_shift": 3.0, + "lower_order_final": True, + "num_train_timesteps": 1000, + "predict_x0": True, + "prediction_type": "flow_prediction", + "rescale_betas_zero_snr": False, + "sample_max_value": 1.0, + "solver_order": 2, + "solver_p": None, + "solver_type": "bh2", + "steps_offset": 0, + "thresholding": False, + "timestep_spacing": "linspace", + "trained_betas": None, + "use_beta_sigmas": False, + "use_exponential_sigmas": False, + "use_flow_sigmas": True, + "use_karras_sigmas": False +} + +# for training. 
I think it is right +scheduler_config = { + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": False +} + + +class AggressiveWanUnloadPipeline(WanPipeline): + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + transformer_2: Optional[WanTransformer3DModel] = None, + boundary_ratio: Optional[float] = None, + expand_timesteps: bool = False, # Wan2.2 ti2v + device: torch.device = torch.device("cuda"), + ): + super().__init__( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + transformer_2=transformer_2, + boundary_ratio=boundary_ratio, + expand_timesteps=expand_timesteps, + vae=vae, + scheduler=scheduler, + ) + self._exec_device = device + @property + def _execution_device(self): + return self._exec_device + + def __call__( + self: WanPipeline, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], + PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # unload vae and transformer + vae_device = 
self.vae.device + transformer_device = self.transformer.device + text_encoder_device = self.text_encoder.device + device = self.transformer.device + + print("Unloading vae") + self.vae.to("cpu") + self.text_encoder.to(device) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # unload text encoder + print("Unloading text encoder") + self.text_encoder.to("cpu") + + self.transformer.to(device) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(device, transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + device, transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - \ + num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(device, transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * \ + (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + # unload transformer + # load vae + print("Loading Vae") + self.vae.to(vae_device) + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, 
self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video( + video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) + + +class Wan21(BaseModel): + arch = 'wan21' + _wan_generation_scheduler_config = scheduler_configUniPC + _wan_expand_timesteps = False + _wan_vae_path = None + + _comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors'] + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__(device, model_config, dtype, + custom_pipeline, noise_scheduler, **kwargs) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['WanTransformer3DModel'] + + # cache for holding noise + self.effective_noise = None + + def get_bucket_divisibility(self): + return 16 + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def load_wan_transformer(self, transformer_path, subfolder=None): + self.print_and_status_update("Loading transformer") + dtype = self.torch_dtype + transformer = WanTransformer3DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ).to(dtype=dtype) + + if self.model_config.split_model_over_gpus: + raise ValueError( + "Splitting model over gpus is not supported for Wan2.1 models") + + if self.model_config.low_vram: + # quantize on the device + 
transformer.to('cpu', dtype=dtype) + flush() + else: + transformer.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + raise ValueError( + "Assistant LoRA is not supported for Wan2.1 models currently") + + if self.model_config.lora_path is not None: + raise ValueError( + "Loading LoRA is not supported for Wan2.1 models currently") + + flush() + + if self.model_config.quantize: + self.print_and_status_update("Quantizing Transformer") + quantize_model(self, transformer) + flush() + + if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0: + MemoryManager.attach( + transformer, + self.device_torch, + offload_percent=self.model_config.layer_offloading_transformer_percent + ) + + if self.model_config.low_vram: + self.print_and_status_update("Moving transformer to CPU") + transformer.to('cpu') + + return transformer + + def load_model(self): + dtype = self.torch_dtype + model_path = self.model_config.name_or_path + + self.print_and_status_update("Loading Wan model") + subfolder = 'transformer' + transformer_path = model_path + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + + te_path = "ai-toolkit/umt5_xxl_encoder" + if os.path.exists(os.path.join(model_path, 'text_encoder')): + te_path = model_path + + vae_path = self.model_config.extras_name_or_path + if os.path.exists(os.path.join(model_path, 'vae')): + vae_path = model_path + + transformer = self.load_wan_transformer( + transformer_path, + subfolder=subfolder, + ) + + flush() + + self.print_and_status_update("Loading UMT5EncoderModel") + + tokenizer, text_encoder = get_umt5_encoder( + model_path=te_path, + tokenizer_subfolder="tokenizer", + encoder_subfolder="text_encoder", + torch_dtype=dtype, + comfy_files=self._comfy_te_file + ) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() 
+ + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing UMT5EncoderModel") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype)) + freeze(text_encoder) + flush() + + if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0: + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent + ) + + if self.model_config.low_vram: + print("Moving transformer back to GPU") + # we can move it back to the gpu now + transformer.to(self.device_torch) + + scheduler = Wan21.get_train_scheduler() + self.print_and_status_update("Loading VAE") + # todo, example does float 32? check if quality suffers + + if self._wan_vae_path is not None: + # load the vae from individual repo + vae = AutoencoderKLWan.from_pretrained( + self._wan_vae_path, torch_dtype=dtype).to(dtype=dtype) + else: + vae = AutoencoderKLWan.from_pretrained( + vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype) + flush() + + self.print_and_status_update("Making pipe") + pipe: WanPipeline = WanPipeline( + scheduler=scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + ) + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = pipe.text_encoder + tokenizer = pipe.tokenizer + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder.to(self.device_torch) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + self.pipeline = pipe + self.model = transformer + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + + def get_generation_pipeline(self): + # todo unipc got broken in a diffusers update. Use euler for now. 
+ # scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config) + scheduler = self.get_train_scheduler() + if self.model_config.low_vram: + pipeline = AggressiveWanUnloadPipeline( + vae=self.vae, + transformer=self.model, + transformer_2=self.model, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + expand_timesteps=self._wan_expand_timesteps, + device=self.device_torch + ) + else: + pipeline = WanPipeline( + vae=self.vae, + transformer=self.unet, + transformer_2=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + expand_timesteps=self._wan_expand_timesteps, + scheduler=scheduler, + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: WanPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) + pipeline = pipeline.to(self.device_torch) + # todo, figure out how to do video + output = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + num_frames=gen_config.num_frames, + generator=generator, + return_dict=False, + output_type="pil", + **extra + )[0] + + # shape = [1, frames, channels, height, width] + batch_item = output[0] # list of pil images + if gen_config.num_frames > 1: + return batch_item # return the frames. 
+ else: + # get just the first image + img = batch_item[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + # vae_scale_factor_spatial = 8 + # vae_scale_factor_temporal = 4 + # num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + # shape = ( + # batch_size, + # num_channels_latents, # 16 + # num_latent_frames, # 81 + # int(height) // self.vae_scale_factor_spatial, + # int(width) // self.vae_scale_factor_spatial, + # ) + + noise_pred = self.model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds, + return_dict=False, + **kwargs + )[0] + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + prompt_embeds, _ = self.pipeline.encode_prompt( + prompt, + do_classifier_free_guidance=False, + max_sequence_length=512, + device=self.device_torch, + dtype=self.torch_dtype, + ) + return PromptEmbeds(prompt_embeds) + + @torch.no_grad() + def encode_images( + self, + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + if self.vae.device == torch.device('cpu'): + self.vae.to(device) + self.vae.eval() + self.vae.requires_grad_(False) + + image_list = [image.to(device, dtype=dtype) for image in image_list] + + # Normalize shapes + norm_images = [] + for image in image_list: + if image.ndim == 3: + # (C, H, W) -> (C, 1, H, W) + norm_images.append(image.unsqueeze(1)) + elif image.ndim == 4: + # (T, C, H, W) -> (C, T, H, W) + norm_images.append(image.permute(1, 0, 2, 3)) + else: + raise ValueError(f"Invalid image shape: {image.shape}") + + # Stack to (B, C, T, H, W) + images = torch.stack(norm_images) + B, C, T, H, 
W = images.shape + + # Resize if needed (B * T, C, H, W) + if H % 8 != 0 or W % 8 != 0: + target_h = H // 8 * 8 + target_w = W // 8 * 8 + images = images.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W) + images = F.interpolate(images, size=(target_h, target_w), mode='bilinear', align_corners=False) + images = images.view(B, T, C, target_h, target_w).permute(0, 2, 1, 3, 4) + + latents = self.vae.encode(images).latent_dist.sample() + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = (latents - latents_mean) * latents_std + + return latents.to(device, dtype=dtype) + + def decode_latents(self, latents: torch.Tensor, device=None, dtype=None): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + if self.vae.device == torch.device('cpu'): + self.vae.to(device) + + latents = latents.to(device, dtype=dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents * latents_std + latents_mean + + images = self.vae.decode(latents).sample + + return images.to(device, dtype=dtype) + + def get_model_has_grad(self): + return False + + def get_te_has_grad(self): + return False + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: Wan21 = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, 'transformer'), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, 
f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + if batch is None: + raise ValueError("Batch is not provided") + if noise is None: + raise ValueError("Noise is not provided") + return (noise - batch.latents).detach() + + def convert_lora_weights_before_save(self, state_dict): + return convert_to_original(state_dict) + + def convert_lora_weights_before_load(self, state_dict): + return convert_to_diffusers(state_dict) + + def get_base_model_version(self): + return "wan_2.1" + + def get_transformer_block_names(self): + return ['blocks'] diff --git a/ai-toolkit/toolkit/models/wan21/wan21_i2v.py b/ai-toolkit/toolkit/models/wan21/wan21_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe012f84d32828ee486e061905539ca61f26856 --- /dev/null +++ b/ai-toolkit/toolkit/models/wan21/wan21_i2v.py @@ -0,0 +1,524 @@ +# WIP, coming soon ish +from functools import partial +import torch +import yaml +from toolkit.accelerator import unwrap_model +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.prompt_utils import PromptEmbeds +from transformers import AutoTokenizer, UMT5EncoderModel +from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel +import os +import sys + +import weakref +import torch +import yaml +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.prompt_utils import PromptEmbeds + +import os +import copy +from toolkit.config_modules import ModelConfig, GenerateImageConfig +import torch +from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler +from transformers import CLIPVisionModel, CLIPImageProcessor +import torch.nn.functional as F + +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.pipelines.wan.pipeline_wan import 
XLA_AVAILABLE +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.video_processor import VideoProcessor +from diffusers.image_processor import PipelineImageInput +from PIL import Image + +from .wan21 import \ + scheduler_configUniPC, \ + scheduler_config, \ + Wan21 + +from .wan_utils import add_first_frame_conditioning + + +class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline): + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + image_processor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModel = None, + transformer_2: WanTransformer3DModel = None, + boundary_ratio: Optional[float] = None, + device: torch.device = torch.device("cuda"), + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + transformer_2=transformer_2, + boundary_ratio=boundary_ratio, + ) + self._exec_device = device + + @property + def _execution_device(self): + return self._exec_device + + @torch.no_grad() + def __call__( + self, + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + 
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # unload vae and transformer + # device = self.transformer.device + device = self._exec_device + + self.text_encoder.to(device) + + self.vae.to('cpu') + self.image_encoder.to('cpu') + flush() + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds=None, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # unload text encoder + self.text_encoder.to("cpu") + self.transformer.to(device) + flush() + + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + self.image_encoder.to(device) + self.vae.to(device) + image_embeds = self.encode_image(image) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + latents, condition = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.bfloat16, + device, + generator, + latents, + ) + self.image_encoder.to('cpu') + self.vae.to('cpu') + flush() + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, # todo I think unconditional should be scaled down version + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + self.vae.to(device) + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + 
class Wan21I2V(Wan21):
    """Image-to-video variant of ``Wan21``.

    On top of what ``Wan21`` loads, this class holds a CLIP vision encoder and
    its image processor. The first frame of each training sample is encoded
    with that vision encoder and is also injected into the latent input via
    ``add_first_frame_conditioning`` (see ``get_noise_prediction``).
    """
    arch = 'wan21_i2v'
    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        """Forward all arguments to ``Wan21`` and declare the I2V-only attributes."""
        super().__init__(
            device, model_config, dtype,
            custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['WanTransformer3DModel']
        # both are populated in load_model()
        self.image_encoder: CLIPVisionModel = None
        self.image_processor: CLIPImageProcessor = None

    def load_model(self):
        """Load the base model, then the CLIP image processor and vision encoder.

        The vision components are first looked up under
        ``model_config.extras_name_or_path``; if that raises anything, they are
        loaded from ``model_config.name_or_path_original`` instead. The encoder
        is frozen (eval mode, no gradients). With ``low_vram`` the text encoder
        and the image encoder are moved to the CPU. The generation pipeline is
        rebuilt at the end so it references the newly loaded components.
        """
        # call the super class to load most of the model
        super().load_model()
        if self.model_config.low_vram:
            # unload text encoder
            self.text_encoder.to("cpu")
        # all the base stuff is loaded. We now need to load the vision encoder stuff
        dtype = self.torch_dtype
        try:
            self.image_processor = CLIPImageProcessor.from_pretrained(
                self.model_config.extras_name_or_path ,
                subfolder="image_processor"
            )
            self.image_encoder = CLIPVisionModel.from_pretrained(
                self.model_config.extras_name_or_path,
                subfolder="image_encoder",
                torch_dtype=dtype,
            )
        except Exception as e:
            # NOTE(review): every exception type triggers the fallback and the
            # original error is discarded, so a real failure under
            # extras_name_or_path is never reported.
            # load from name_or_path
            self.image_processor = CLIPImageProcessor.from_pretrained(
                self.model_config.name_or_path_original,
                subfolder="image_processor"
            )
            self.image_encoder = CLIPVisionModel.from_pretrained(
                self.model_config.name_or_path_original,
                subfolder="image_encoder",
                torch_dtype=dtype,
            )
        self.image_encoder.to(self.device_torch, dtype=dtype)
        self.image_encoder.eval()
        self.image_encoder.requires_grad_(False)

        if self.model_config.low_vram:
            # unload image encoder
            self.image_encoder.to("cpu")

        # rebuild the pipeline
        self.pipeline = self.get_generation_pipeline()
        flush()

    def generate_images(
        self,
        image_configs,
        sampler=None,
        pipeline=None,
    ):
        """Generate samples, first offloading the big modules when ``low_vram`` is set.

        Delegates to ``Wan21.generate_images``. Nothing is returned here, so
        whatever the parent returns is dropped.
        """
        # will oom on 24gb vram if we dont unload vision encoder first
        if self.model_config.low_vram:
            # unload image encoder
            self.image_encoder.to("cpu")
            self.vae.to("cpu")
            self.transformer.to("cpu")
            flush()
        super().generate_images(
            image_configs,
            sampler=sampler,
            pipeline=pipeline,
        )

    def set_device_state_preset(self, *args, **kwargs):
        """Apply a device-state preset, except in ``low_vram`` mode where this is a no-op."""
        # set the device state to cpu for the image encoder
        if self.model_config.low_vram:
            return
        super().set_device_state_preset(*args, **kwargs)


    def get_generation_pipeline(self):
        """Build the sampling pipeline from the already-loaded components.

        ``low_vram`` selects ``AggressiveWanI2VUnloadPipeline`` (which also
        receives the target device); otherwise a stock
        ``WanImageToVideoPipeline`` is built. The pipeline is not moved to a
        device here.
        """
        # todo unipc got broken in a diffusers update. Use euler for now.
        # scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
        scheduler = self.get_train_scheduler()
        if self.model_config.low_vram:
            # NOTE(review): this branch passes self.model while the other passes
            # self.unet — presumably aliases of the same transformer; confirm.
            pipeline = AggressiveWanI2VUnloadPipeline(
                vae=self.vae,
                transformer=self.model,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=scheduler,
                image_encoder=self.image_encoder,
                image_processor=self.image_processor,
                device=self.device_torch
            )
        else:
            pipeline = WanImageToVideoPipeline(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=scheduler,
                image_encoder=self.image_encoder,
                image_processor=self.image_processor,
            )

        # pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: WanImageToVideoPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Run the pipeline for one sample config.

        The control image at ``gen_config.ctrl_img`` is opened as RGB and
        resized to the requested size after height and width are floored to
        multiples of 16.

        Returns:
            The list of PIL frames when ``gen_config.num_frames > 1``,
            otherwise only the first frame.

        Raises:
            ValueError: if ``gen_config.ctrl_img`` is ``None``.
        """
        # reactivate progress bar since this is slooooow
        pipeline.set_progress_bar_config(disable=False)
        # pipeline = pipeline.to(self.device_torch)


        if gen_config.ctrl_img is None:
            raise ValueError("I2V samples must have a control image")

        control_img = Image.open(gen_config.ctrl_img).convert("RGB")

        height = gen_config.height
        width = gen_config.width

        # make sure they are divisible by 16
        height = height // 16 * 16
        width = width // 16 * 16

        # resize the control image
        control_img = control_img.resize((width, height), Image.LANCZOS)

        output = pipeline(
            image=control_img,
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            height=height,
            width=width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="pil",
            **extra
        )[0]

        # shape = [1, frames, channels, height, width]
        batch_item = output[0]  # list of pil images
        if gen_config.num_frames > 1:
            return batch_item  # return the frames.
        else:
            # get just the first image
            img = batch_item[0]
            return img


    def preprocess_clip_image(self, image_n1p1):
        """Turn a batch of images in ``[-1, 1]`` into normalized 224x224 CLIP input.

        Steps: rescale to ``[0, 1]``, bilinear-resize to 224x224 when needed,
        clamp, then normalize with ``image_processor.image_mean`` and
        ``image_processor.image_std``. The result is detached from the graph.
        """
        # tensor shape: (bs, ch, height, width) with values in range [-1, 1]
        # Convert from [-1, 1] to [0, 1] range
        tensor = (image_n1p1 + 1) / 2

        # Resize to 224x224 using bilinear interpolation.
        # NOTE(review): the previous comment equated this with PIL resample=3, but
        # in Pillow resample=3 is BICUBIC (BILINEAR is 2). If the goal is to mirror
        # the image processor's own resize, the mode may need to change — confirm.
        if tensor.shape[2] != 224 or tensor.shape[3] != 224:
            tensor = F.interpolate(tensor, size=(224, 224), mode='bilinear', align_corners=False)

        tensors_0_1 = tensor.clamp(0, 1)  # Ensure values are in [0, 1] range

        # per-channel statistics, shaped to broadcast over (bs, 3, h, w)
        mean = torch.tensor(self.image_processor.image_mean).to(
            tensors_0_1.device, dtype=tensors_0_1.dtype
        ).view([1, 3, 1, 1]).detach()
        std = torch.tensor(self.image_processor.image_std).to(
            tensors_0_1.device, dtype=tensors_0_1.dtype
        ).view([1, 3, 1, 1]).detach()

        # tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
        clip_image = (tensors_0_1 - mean) / std

        return clip_image.detach()

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: DataLoaderBatchDTO,
        **kwargs
    ):
        """Predict noise for a batch, conditioning on each sample's first frame.

        Under ``torch.no_grad()`` the first frame of ``batch.tensor`` is encoded
        by the image encoder (second-to-last hidden state) and merged into the
        latent input by ``add_first_frame_conditioning``. The transformer call
        itself happens outside the ``no_grad`` block.

        Raises:
            ValueError: if ``batch.tensor`` is neither 4-D nor 5-D.
        """
        # videos come in (bs, num_frames, channels, height, width)
        # images come in (bs, channels, height, width)
        with torch.no_grad():
            frames = batch.tensor
            if len(frames.shape) == 4:
                first_frames = frames
            elif len(frames.shape) == 5:
                first_frames = frames[:, 0]
            else:
                raise ValueError(f"Unknown frame shape {frames.shape}")

            # first_frames shape is (bs, channels, height, width), -1 to 1
            preprocessed_frames = self.preprocess_clip_image(first_frames)
            preprocessed_frames = preprocessed_frames.to(self.device_torch, dtype=self.torch_dtype)
            # preprocessed_frame shape is (bs, 3, 224, 224)
            # the encoder may have been offloaded to the CPU (see load_model)
            self.image_encoder.to(self.device_torch)
            image_embeds_full = self.image_encoder(preprocessed_frames, output_hidden_states=True)
            image_embeds = image_embeds_full.hidden_states[-2]
            image_embeds = image_embeds.to(self.device_torch, dtype=self.torch_dtype)

            # Add conditioning using the standalone function
            conditioned_latent = add_first_frame_conditioning(
                latent_model_input=latent_model_input,
                first_frame=first_frames,
                vae=self.vae
            )

        noise_pred = self.model(
            hidden_states=conditioned_latent,
            timestep=timestep,
            encoder_hidden_states=text_embeddings.text_embeds,
            encoder_hidden_states_image=image_embeds,
            return_dict=False,
            **kwargs
        )[0]
        return noise_pred
# modified to set the image embedder size
class WanAttnProcessor2_0:
    """Attention processor for Wan transformer blocks with a configurable image-token count.

    When the attention layer has ``add_k_proj``, the first ``num_img_tokens``
    entries of ``encoder_hidden_states`` are treated as image tokens and get
    their own key/value projections; the remaining entries go through the
    regular key/value path. The two attention outputs are summed before the
    output projection.
    """

    def __init__(self, num_img_tokens: int = 257):
        self.num_img_tokens = num_img_tokens
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def _to_heads(tensor: torch.Tensor, num_heads: int) -> torch.Tensor:
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return tensor.unflatten(2, (num_heads, -1)).transpose(1, 2)

    @staticmethod
    def _from_heads(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim), typed like `reference`
        return tensor.transpose(1, 2).flatten(2, 3).type_as(reference)

    @staticmethod
    def _rotate(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        # pair up the last dim as complex numbers (in float64), rotate, then unpack again
        as_complex = torch.view_as_complex(
            tensor.to(torch.float64).unflatten(3, (-1, 2)))
        rotated = torch.view_as_real(as_complex * freqs).flatten(3, 4)
        return rotated.type_as(tensor)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # peel the image tokens off the front of the context when this layer has image projections
        image_context = None
        if attn.add_k_proj is not None:
            image_context = encoder_hidden_states[:, :self.num_img_tokens]
            encoder_hidden_states = encoder_hidden_states[:, self.num_img_tokens:]

        # without a context this is self-attention
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        q = attn.to_q(hidden_states)
        k = attn.to_k(context)
        v = attn.to_v(context)

        if attn.norm_q is not None:
            q = attn.norm_q(q)
        if attn.norm_k is not None:
            k = attn.norm_k(k)

        q = self._to_heads(q, attn.heads)
        k = self._to_heads(k, attn.heads)
        v = self._to_heads(v, attn.heads)

        if rotary_emb is not None:
            q = self._rotate(q, rotary_emb)
            k = self._rotate(k, rotary_emb)

        # I2V task: attend over the image tokens with their dedicated projections
        image_attention = None
        if image_context is not None:
            k_img = attn.norm_added_k(attn.add_k_proj(image_context))
            v_img = attn.add_v_proj(image_context)
            image_attention = F.scaled_dot_product_attention(
                q,
                self._to_heads(k_img, attn.heads),
                self._to_heads(v_img, attn.heads),
                attn_mask=None, dropout_p=0.0, is_causal=False
            )
            image_attention = self._from_heads(image_attention, q)

        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = self._from_heads(attended, q)

        if image_attention is not None:
            attended = attended + image_attention

        projected = attn.to_out[0](attended)
        projected = attn.to_out[1](projected)
        return projected
def convert_to_original(state_dict):
    """Rename diffusers-style Wan LoRA keys back to the original naming.

    Each key goes through four independent renaming stages; within a stage only
    the first matching pattern is applied (to every occurrence in the key).
    Values are carried over unchanged and a new dict is returned.
    """
    # stages after the prefix swap; order inside each stage is significant
    rename_stages = (
        # attention blocks
        (("attn1", "self_attn"), ("attn2", "cross_attn")),
        # attention components, then the image-attention projections
        (
            ("to_out.0", "o"),
            ("to_q", "q"),
            ("to_k", "k"),
            ("to_v", "v"),
            ("add_k_proj", "k_img"),
            ("add_v_proj", "v_img"),
        ),
        # feed-forward layers
        (("ffn.net.0.proj", "ffn.0"), ("ffn.net.2", "ffn.2")),
    )

    converted = {}
    for old_key, value in state_dict.items():
        renamed = old_key
        # base model name change
        if old_key.startswith("transformer."):
            renamed = old_key.replace("transformer.", "diffusion_model.")

        for stage in rename_stages:
            for needle, replacement in stage:
                if needle in renamed:
                    renamed = renamed.replace(needle, replacement)
                    break

        converted[renamed] = value
    return converted
def add_first_frame_conditioning(
        latent_model_input,
        first_frame,
        vae
):
    """
    Adds first frame conditioning to a video diffusion model input.

    A pixel-space video is built whose first frame is ``first_frame`` and whose
    remaining frames are zero. It is VAE-encoded and normalized with the VAE's
    ``latents_mean`` / ``latents_std``, then concatenated with a mask (and with
    the original latents) along the channel axis.

    The pixel frame count is derived from the VAE's temporal compression factor
    (``2 ** sum(vae.temperal_downsample)``) instead of a hard-coded 4, so the
    mask and the encoded condition stay consistent with each other.

    Args:
        latent_model_input: Original latent input (bs, channels, num_frames, height, width)
        first_frame: Tensor of first frame to condition on (bs, channels, height, width);
            a single (channels, height, width) image is also accepted and is
            expanded to the batch size.
        vae: VAE model for encoding the conditioning

    Returns:
        conditioned_latent: latent input, mask and encoded condition concatenated on
            dim 1 (36 channels for 16 latent channels and a temporal factor of 4).
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
    vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)

    batch_size, _, num_latent_frames, latent_in_h, latent_in_w = latent_model_input.shape

    # For n original frames there are (n-1)//t + 1 latent frames (t = temporal factor),
    # so n = (num_latent_frames-1)*t + 1
    num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1

    if len(first_frame.shape) == 3:
        # we have a single image
        first_frame = first_frame.unsqueeze(0)

    # if it doesnt match the batch size, we need to expand it
    if first_frame.shape[0] != batch_size:
        first_frame = first_frame.expand(batch_size, -1, -1, -1)

    # resize first frame to match the latent model input
    vae_scale_factor = vae.config.scale_factor_spatial
    first_frame = F.interpolate(
        first_frame,
        size=(latent_in_h * vae_scale_factor, latent_in_w * vae_scale_factor),
        mode='bilinear',
        align_corners=False
    )

    # Video condition: first frame followed by zeros, (bs, channels, num_frames, height, width).
    # Allocated in one go rather than concatenating num_frames separate tensors.
    video_condition = first_frame.new_zeros(
        first_frame.shape[0], first_frame.shape[1], num_frames,
        first_frame.shape[2], first_frame.shape[3]
    )
    video_condition[:, :, 0] = first_frame

    # Encode with VAE
    latent_condition = vae.encode(
        video_condition.to(device, dtype)
    ).latent_dist.sample()
    latent_condition = latent_condition.to(device, dtype)

    latents_mean = (
        torch.tensor(vae.config.latents_mean)
        .view(1, vae.config.z_dim, 1, 1, 1)
        .to(device, dtype)
    )
    latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
        device, dtype
    )
    latent_condition = (latent_condition - latents_mean) * latents_std

    # Mask: 1 for conditioning frames, 0 for frames to generate. Each latent frame
    # carries one mask channel per pixel frame it covers; only the first latent
    # frame (the repeated first pixel frame) is marked as conditioned.
    mask_lat_size = torch.zeros(
        batch_size, vae_scale_factor_temporal, num_latent_frames,
        latent_condition.shape[3], latent_condition.shape[4],
        device=device, dtype=dtype
    )
    mask_lat_size[:, :, 0] = 1

    # Combine conditioning with latent input
    first_frame_condition = torch.concat(
        [mask_lat_size, latent_condition], dim=1)
    conditioned_latent = torch.cat(
        [latent_model_input, first_frame_condition], dim=1)

    return conditioned_latent
def add_first_frame_conditioning_v22(
    latent_model_input,
    first_frame,
    vae,
    last_frame=None
):
    """
    Overwrites first few time steps in latent_model_input with VAE-encoded first_frame,
    and returns the modified latent + binary mask (0=conditioned, 1=noise).

    When ``last_frame`` is given, the trailing time steps are overwritten the same
    way. Both frames go through identical preparation, so ``last_frame`` may also be
    a single (3, H, W) image or have a batch size that needs expanding, which
    previously only worked for ``first_frame``.

    Args:
        latent_model_input: torch.Tensor of shape (bs, 48, T, H, W)
        first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale)
        vae: VAE model with .encode() and .config.latents_mean/std
        last_frame: optional tensor shaped like ``first_frame``

    Returns:
        latent: (bs, 48, T, H, W) - modified copy of the input latent
        mask: (bs, 1, T, H, W) - binary mask
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
    bs, _, T, H, W = latent_model_input.shape
    scale = vae.config.scale_factor_spatial
    target_h = H * scale
    target_w = W * scale

    # Normalization constants shared by both frames
    mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
    std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)

    def _encode_frame(frame):
        # Ensure shape (bs, 3, h, w)
        if frame.ndim == 3:
            frame = frame.unsqueeze(0)
        if frame.shape[0] != bs:
            frame = frame.expand(bs, -1, -1, -1)
        # Resize to the pixel size matching the latent, add a time axis, encode, normalize
        frame_up = F.interpolate(frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
        frame_up = frame_up.unsqueeze(2)  # (bs, 3, 1, H, W)
        encoded = vae.encode(frame_up).latent_dist.sample().to(dtype).to(device)
        return (encoded - mean) * std

    # work on a copy so the caller's latent is left untouched
    latent = latent_model_input.clone()
    # Mask: 0 where conditioned, 1 otherwise
    mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)

    encoded = _encode_frame(first_frame)
    latent[:, :, :encoded.shape[2]] = encoded  # typically first frame: [:, :, 0]
    mask[:, :, :encoded.shape[2]] = 0.0

    if last_frame is not None:
        last_encoded = _encode_frame(last_frame)
        latent[:, :, -last_encoded.shape[2]:] = last_encoded  # replace last
        mask[:, :, -last_encoded.shape[2]:] = 0.0

    # mask only ever holds 0.0 / 1.0, so no clamping is needed
    return latent, mask
class ZipperResampler(nn.Module):
    """Stack of ``ZipperModule`` blocks followed by a learned per-token alpha gate.

    The first block consumes the input size/token count, the last block emits the
    output size/token count, and every boundary in between uses the hidden
    size/token count. Only blocks that are neither first nor last use a residual
    connection. The output of the stack is multiplied by ``ContextualAlphaMask``.
    """

    def __init__(
            self,
            in_size,
            in_tokens,
            out_size,
            out_tokens,
            hidden_size,
            hidden_tokens,
            num_blocks=1,
            is_conv_input=False,
    ):
        super().__init__()
        self.is_conv_input = is_conv_input

        last_index = num_blocks - 1
        stages = []
        for index in range(num_blocks):
            is_first = index == 0
            is_last = index == last_index
            stages.append(ZipperModule(
                # only the outer edges of the stack see the external sizes
                in_size=in_size if is_first else hidden_size,
                in_tokens=in_tokens if is_first else hidden_tokens,
                out_size=out_size if is_last else hidden_size,
                out_tokens=out_tokens if is_last else hidden_tokens,
                hidden_size=hidden_size,
                hidden_tokens=hidden_tokens,
                # middle blocks keep their shape, so they can be residual
                use_residual=not (is_first or is_last)
            ))

        self.blocks = nn.ModuleList(stages)

        self.ctx_alpha = ContextualAlphaMask(
            dim=out_size,
        )

    def forward(self, x):
        if self.is_conv_input:
            # flatten everything after the channel axis, then move to (batch, tokens, size)
            x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)

        for stage in self.blocks:
            x = stage(x)
        return x * self.ctx_alpha(x)
# Forward references to the concrete network / module classes (imported under TYPE_CHECKING).
Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule']

# Class names (matched against ``__class__.__name__``) handled as linear layers.
LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear',
    'QLinear'
    # 'GroupNorm',
]
# Class names handled as convolution layers.
CONV_MODULES = [
    'Conv2d',
    'LoRACompatibleConv'
]

# Allowed values for the ``extract_mode`` argument of weight extraction.
# This was a ``Union`` of bare strings with a missing comma, which silently merged
# 'existing' and 'fixed' into the single entry 'existingfixed'.
ExtractMode = Literal[
    'existing',
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentage'
]

# messages already emitted through print_once()
printed_messages = []


def print_once(msg):
    """Print ``msg`` the first time it is seen; later calls with the same message are silent."""
    global printed_messages
    if msg not in printed_messages:
        print(msg)
        printed_messages.append(msg)
def add_bias(tensor, bias):
    """Add a per-channel ``bias`` to ``tensor``, returning ``tensor`` unchanged when ``bias`` is None.

    The bias is repeated along a new leading batch axis, padded with trailing
    singleton axes up to the rank of ``tensor`` and, if its second axis does not
    match the tensor's, permuted so the channel axis moves (3-D and 4-D only).
    On a broadcasting failure the error and both shapes are printed before the
    error is re-raised.
    """
    if bias is None:
        return tensor

    # add batch dim and repeat once per batch entry
    batched = torch.cat([bias.unsqueeze(0)] * tensor.size(0), dim=0)

    # pad with trailing singleton dims until the ranks match
    missing_dims = tensor.dim() - batched.dim()
    while missing_dims > 0:
        batched = batched.unsqueeze(-1)
        missing_dims -= 1

    # we may need to swap -1 for -2
    if batched.size(1) != tensor.size(1):
        rank = len(batched.size())
        if rank == 3:
            batched = batched.permute(0, 2, 1)
        elif rank == 4:
            batched = batched.permute(0, 3, 1, 2)

    # add the broadcast bias to the tensor
    try:
        return tensor + batched
    except RuntimeError as e:
        print(e)
        print(tensor.size())
        print(batched.size())
        raise e
self.org_module[0].__class__.__name__ in LINEAR_MODULES: + # do linear extraction + down_weight, up_weight, new_dim, diff = extract_linear( + weight=weight_to_extract, + mode=extract_mode, + mode_param=extract_mode_param, + device=device, + ) + else: + raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}") + + self.lora_dim = new_dim + + # inject weights into the param + self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach() + self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach() + + # copy bias if we have one and are using them + if self.org_module[0].bias is not None and self.lora_up.bias is not None: + self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach() + + # set up alphas + self.alpha = (self.alpha * 0) + down_weight.shape[0] + self.scale = self.alpha / self.lora_dim + + # assign them + + # handle trainable scaler method locon does + if hasattr(self, 'scalar'): + # scaler is a parameter update the value with 1.0 + self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype) + + +class ToolkitModuleMixin: + def __init__( + self: Module, + *args, + network: Network, + **kwargs + ): + self.network_ref: weakref.ref = weakref.ref(network) + self.is_checkpointing = False + self._multiplier: Union[float, list, torch.Tensor] = None + + def _call_forward(self: Module, x): + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return 0.0 # added to original forward + + if hasattr(self, 'lora_mid') and self.lora_mid is not None: + lx = self.lora_mid(self.lora_down(x)) + else: + try: + lx = self.lora_down(x) + except RuntimeError as e: + print(f"Error in {self.__class__.__name__} lora_down") + raise e + + if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity): + lx = self.dropout(lx) + # normal dropout + elif self.dropout is not None and self.training: + lx 
def lorm_forward(self: 'Module', x, *args, **kwargs):
    """Forward pass used when the network runs in LoRM mode.

    * Network inactive: the original module's forward is used unchanged.
    * ``lorm_train_mode == 'local'``: the original forward produces the target,
      the low-rank pair (``lora_up(lora_down(x))``) is trained against it with an
      MSE loss that is back-propagated immediately, the detached loss is appended
      to ``network.module_losses``, and the *original* prediction is returned so
      downstream modules are unaffected by this module's approximation.
    * Otherwise: the low-rank pair replaces the original module and its output is
      returned in the input's dtype.

    The input is cast to the dtype of ``lora_down.weight`` before use.
    """
    network = self.network_ref()
    if not network.is_active:
        return self.org_forward(x, *args, **kwargs)

    orig_dtype = x.dtype

    if x.dtype != self.lora_down.weight.dtype:
        x = x.to(self.lora_down.weight.dtype)

    if network.lorm_train_mode == 'local':
        # we are going to predict input with both and do a loss on them
        inputs = x.detach()
        with torch.no_grad():
            # get the local prediction
            target_pred = self.org_forward(inputs, *args, **kwargs).detach()
        with torch.set_grad_enabled(True):
            # make a prediction with the lorm
            lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True)))

            local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
            # backprop right away; the loss is local to this module
            local_loss.backward()

        network.module_losses.append(local_loss.detach())
        # return the original as we dont want our trainer to affect ones down the line
        return target_pred

    x = self.lora_up(self.lora_down(x))
    if x.dtype != orig_dtype:
        x = x.to(orig_dtype)
    # this return was missing, which made the LoRM replacement path yield None
    return x
def forward(self: "Module", x, *args, **kwargs):
    """Forward pass of a wrapped module: original output plus the LoRA delta.

    Dispatch order:
      1. LoRM networks go through ``lorm_forward``.
      2. If the network is inactive, merged in, or has a zero multiplier, only
         the original forward runs.
      3. ``LokrModule`` instances return ``_call_forward(x)`` directly.
      4. Otherwise the LoRA branch output is scaled by the network's
         ``torch_multiplier`` and added to the original output (with an extra
         DoRA term for ``DoRAModule``).
    """
    network = self.network_ref()
    if network.is_lorm:
        return self.lorm_forward(x, *args, **kwargs)

    # network not contributing: avoid doing anything beyond the original call
    if (not network.is_active) or network.is_merged_in or network._multiplier == 0:
        return self.org_forward(x, *args, **kwargs)

    if self.__class__.__name__ == "LokrModule":
        return self._call_forward(x)

    org_forwarded = self.org_forward(x, *args, **kwargs)

    if isinstance(x, QTensor):
        x = x.dequantize()
    # run the LoRA branch in the dtype of its own weights
    lora_input = x.to(self.lora_down.weight.dtype)
    lora_output = self._call_forward(lora_input)
    if not isinstance(lora_output, torch.Tensor):
        # _call_forward returns the float 0.0 when module dropout fires; there
        # is nothing to add, and the tensor ops below would fail on a float
        return org_forwarded

    multiplier = network.torch_multiplier

    lora_batch = lora_output.size(0)
    multiplier_batch = multiplier.size(0)
    if lora_batch != multiplier_batch:
        # todo check if this is correct, do we just concat when doing cfg?
        multiplier = multiplier.repeat_interleave(lora_batch // multiplier_batch)

    scaled_lora_output = broadcast_and_multiply(lora_output, multiplier)
    scaled_lora_output = scaled_lora_output.to(org_forwarded.dtype)

    if self.__class__.__name__ == "DoRAModule":
        # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417
        # todo this wont match the dropout applied to the lora
        if isinstance(self.dropout, (nn.Dropout, nn.Identity)):
            lx = self.dropout(x)
        elif self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(x, p=self.dropout)
        else:
            lx = x
        lora_weight = self.lora_up.weight @ self.lora_down.weight
        # todo handle our batch split scalers for slider training. For now take the mean of them
        scale = multiplier.mean()
        scaled_lora_weight = lora_weight * scale
        scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight).to(org_forwarded.dtype)

    try:
        x = org_forwarded + scaled_lora_output
    except RuntimeError as e:
        # surface both shapes to make broadcast mismatches diagnosable
        print(e)
        print(org_forwarded.size())
        print(scaled_lora_output.size())
        raise e
    return x
@torch.no_grad()
def merge_in(self: "Module", merge_weight=1.0):
    """Fold this module's delta into the wrapped module's weight.

    Does nothing when ``can_merge_in`` is false or when the original state
    dict contains a ``weight._data`` entry (quantized weight that cannot be
    merged yet). Weights flagged by ``is_quantized_tensor`` are dequantized
    for the merge and handed to ``requantize_module_weight`` afterwards.

    Args:
        merge_weight: multiplier applied to the delta; ``merge_out`` passes a
            negative value to undo a previous merge.
    """
    if not self.can_merge_in:
        return

    up_weight = None if self.full_rank else self.lora_up.weight.clone().float()
    down_weight = self.lora_down.weight.clone().float()

    org_module = self.org_module[0]
    org_sd = org_module.state_dict()
    # todo find a way to merge in weights when doing quantized model
    if 'weight._data' in org_sd:
        return

    from toolkit.util.quantize import is_quantized_tensor
    org_weight = org_module.weight
    is_ao_quantized = is_quantized_tensor(org_weight)
    orig_dtype = org_weight.dtype
    # dequantize torchao weights so the delta can be merged in full precision
    weight = (org_weight.dequantize() if is_ao_quantized else org_weight).float()

    scale = self.scale
    # trainable scalar (LoCon-style), when the module defines one
    if hasattr(self, 'scalar'):
        scale = scale * self.scalar

    # do the arithmetic on the LoRA weights' device, then move the result back
    weight_device = weight.device
    if weight.device != down_weight.device:
        weight = weight.to(down_weight.device)
    # scale is only a tensor in some configurations; a plain number has no
    # .device and needs no transfer
    if isinstance(scale, torch.Tensor) and scale.device != down_weight.device:
        scale = scale.to(down_weight.device)

    if self.full_rank:
        delta = down_weight
    elif len(weight.size()) == 2:
        # linear
        delta = up_weight @ down_weight
    elif down_weight.size()[2:4] == (1, 1):
        # conv2d 1x1
        delta = (
            up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
        ).unsqueeze(2).unsqueeze(3)
    else:
        # conv2d 3x3
        delta = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
    weight = weight + merge_weight * delta * scale

    # write the merged weight back, re-quantizing if the original was torchao
    # quantized so the model stays quantized across merge/reset cycles
    if is_ao_quantized:
        from toolkit.util.quantize import get_torchao_config, requantize_module_weight
        config = get_torchao_config(self._get_base_qtype())
        if config is None:
            print_once(f"Warning: merging into quantized layer {getattr(self, 'lora_name', '?')} "
                       f"without a known qtype; it will be left dequantized")
        requantize_module_weight(org_module, weight.to(weight_device), orig_dtype, config)
    else:
        org_sd["weight"] = weight.to(weight_device, orig_dtype)
        org_module.load_state_dict(org_sd)
def get_keymap(self: "Network", force_weight_mapping=False):
    """Load the ldm-to-diffusers key map for this network's base model family.

    The JSON file is chosen from the ``is_ssd`` / ``is_vega`` / ``is_sdxl`` /
    ``is_v2`` flags (falling back to ``sd1``). When weight mapping is in
    effect the model keymap is converted with
    ``get_lora_keymap_from_model_keymap``. For DoRA networks every ``.alpha``
    in keys and values is renamed to ``.magnitude``.

    Returns:
        The key map dict, or ``None`` when the keymap file does not exist.
    """
    # (flag attribute, file suffix, uses weight mapping), in priority order
    variants = (
        ('is_ssd', 'ssd', True),
        ('is_vega', 'vega', True),
        ('is_sdxl', 'sdxl', False),
        ('is_v2', 'sd2', False),
    )
    keymap_tail, use_weight_mapping = 'sd1', False
    for flag_name, tail, needs_mapping in variants:
        if getattr(self, flag_name):
            keymap_tail, use_weight_mapping = tail, needs_mapping
            break

    if force_weight_mapping:
        use_weight_mapping = True

    if use_weight_mapping:
        keymap_name = f"stable_diffusion_{keymap_tail}.json"
    else:
        keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
    keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)

    keymap = None
    if os.path.exists(keymap_path):
        with open(keymap_path, 'r') as f:
            keymap = json.load(f)['ldm_diffusers_keymap']

    if use_weight_mapping and keymap is not None:
        # derive the LoRA keymap from the model weight keymap
        keymap = get_lora_keymap_from_model_keymap(keymap)

    # upgrade keymaps for DoRA
    if self.network_type.lower() == 'dora' and keymap is not None:
        keymap = {
            ldm_key.replace('.alpha', '.magnitude'): diffusers_key.replace('.alpha', '.magnitude')
            for ldm_key, diffusers_key in keymap.items()
        }

    return keymap
def save_weights(
        self: "Network",
        file, dtype=torch.float16,
        metadata=None,
        extra_state_dict: Optional[OrderedDict] = None
):
    """Serialize the network weights to ``file``.

    The state dict comes from ``get_state_dict``. Missing or empty metadata is
    replaced with a fresh ``OrderedDict`` before ``add_model_hash_to_meta``
    runs. If the base model exposes ``save_lora`` it does the writing;
    otherwise ``.safetensors`` paths use safetensors and everything else uses
    ``torch.save``.
    """
    save_dict = self.get_state_dict(extra_state_dict=extra_state_dict, dtype=dtype)

    if metadata is None or len(metadata) == 0:
        metadata = OrderedDict()
    metadata = add_model_hash_to_meta(save_dict, metadata)

    # let the base model handle the saving when it knows how
    base_model = self.base_model_ref() if self.base_model_ref is not None else None
    if base_model is not None and hasattr(base_model, 'save_lora'):
        base_model.save_lora(save_dict, file, metadata)
        return

    _, extension = os.path.splitext(file)
    if extension == ".safetensors":
        from safetensors.torch import save_file
        save_file(save_dict, file, metadata)
    else:
        torch.save(save_dict, file)
def load_weights(self: "Network", file, force_weight_mapping=False):
    """Load LoRA weights from a path or an in-memory state dict.

    Keys are translated through the keymap and, for PEFT-format networks, from
    ``lora_A`` / ``lora_B`` dotted names to this network's ``$$``-separated
    module names. Two-dimensional ``lora_down`` / ``lora_up`` tensors whose
    shape differs from the current network are zero-padded or truncated, and
    ``did_change_weights`` is set when that happens.

    Returns:
        An ``OrderedDict`` of loaded entries that have no matching key in this
        network, or ``None`` when there are none.

    Raises:
        ValueError: for a LoRA shape mismatch that is neither a plain
            expansion nor a plain shrink.
    """
    # allows us to save and load to and from ldm weights
    keymap = self.get_keymap(force_weight_mapping) or {}

    if not isinstance(file, str):
        # probably a state dict
        weights_sd = file
    elif self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'):
        weights_sd = self.base_model_ref().load_lora(file)
    elif os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file
        weights_sd = load_file(file)
    else:
        weights_sd = torch.load(file, map_location="cpu")

    if self.base_model_ref is not None:
        weights_sd = self.base_model_ref().convert_lora_weights_before_load(weights_sd)

    load_sd = OrderedDict()
    for key, value in weights_sd.items():
        load_key = keymap.get(key, key)
        # replace old double __ with single _
        if self.is_pixart:
            load_key = load_key.replace('__', '_')

        if self.peft_format:
            # peft naming: lora_A == lora_down, lora_B == lora_up, no alpha
            if load_key.endswith('.alpha') and self.network_type.lower() != "lokr":
                continue
            load_key = load_key.replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')
            # module path separators become $$, then the parameter part is restored
            load_key = load_key.replace('.', '$$')
            load_key = load_key.replace('$$lora_down$$', '.lora_down.')
            load_key = load_key.replace('$$lora_up$$', '.lora_up.')
            # full weight modules store their delta as `.diff` / `.diff_b`
            # (anchored at the end, so other keys are untouched)
            for tail in ('diff', 'diff_b'):
                marker = '$$' + tail
                if load_key.endswith(marker):
                    load_key = load_key[:-len(marker)] + '.' + tail
                    break

            # patch lokr, not sure why we need to but whatever
            if self.network_type.lower() == "lokr":
                load_key = load_key.replace('$$lokr_w1', '.lokr_w1')
                load_key = load_key.replace('$$lokr_w2', '.lokr_w2')
                if load_key.endswith('$$alpha'):
                    load_key = load_key[:-7] + '.alpha'

        if self.network_type.lower() == "lokr":
            # lycoris_transformer_blocks_7_attn_to_v.lokr_w1 -> lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1
            load_key = load_key.replace('lycoris_', 'lora_transformer_')

        load_sd[load_key] = value

    # split off entries this network has no slot for, and resize LoRA matrices
    current_state_dict = self.state_dict()
    extra_dict = OrderedDict()
    to_delete = []
    for key in list(load_sd.keys()):
        if key not in current_state_dict:
            extra_dict[key] = load_sd[key]
            to_delete.append(key)
            continue

        is_down = "lora_down" in key
        is_up = "lora_up" in key
        if not (is_down or is_up):
            continue
        load_value = load_sd[key]  # from checkpoint
        # handle expanding/shrinking LoRA (linear only)
        if len(load_value.shape) != 2:
            continue
        blank_val = current_state_dict[key]  # shape we need in the target model
        tgt_h, tgt_w = blank_val.shape
        src_h, src_w = load_value.shape

        if (src_h, src_w) == (tgt_h, tgt_w):
            continue

        if (is_down and src_h < tgt_h) or (is_up and src_w < tgt_w):
            print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
            new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
            new_val[:src_h, :src_w] = load_value  # the other dimension should already match
            load_sd[key] = new_val
        elif (is_down and src_h > tgt_h) or (is_up and src_w > tgt_w):
            print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
            load_sd[key] = load_value[:tgt_h, :tgt_w]
        else:
            # unexpected mismatch (e.g., both dims differ in a way that doesn't match lora_up/down semantics)
            raise ValueError(f"Unhandled LoRA shape change for {key}: src={load_value.shape}, tgt={blank_val.shape}")
        self.did_change_weights = True

    for key in to_delete:
        del load_sd[key]

    print(f"Missing keys: {to_delete}")
    only_emb_params = len(to_delete) == 1 and 'emb_params' in to_delete
    if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not only_emb_params:
        print(" Attempting to load with forced keymap")
        return self.load_weights(file, force_weight_mapping=True)

    self.load_state_dict(load_sd, False)
    return extra_dict if len(extra_dict.keys()) > 0 else None
Check your config and try again") + + if hasattr(first_module, 'lora_down'): + device = first_module.lora_down.weight.device + dtype = first_module.lora_down.weight.dtype + if hasattr(first_module.lora_down, '_memory_management_device'): + device = first_module.lora_down._memory_management_device + elif hasattr(first_module, 'lokr_w1'): + device = first_module.lokr_w1.device + dtype = first_module.lokr_w1.dtype + if hasattr(first_module.lokr_w1, '_memory_management_device'): + device = first_module.lokr_w1._memory_management_device + elif hasattr(first_module, 'lokr_w1_a'): + device = first_module.lokr_w1_a.device + dtype = first_module.lokr_w1_a.dtype + if hasattr(first_module.lokr_w1_a, '_memory_management_device'): + device = first_module.lokr_w1_a._memory_management_device + elif hasattr(first_module, 'diff'): + # full weight module + device = first_module.diff.device + dtype = first_module.diff.dtype + else: + raise ValueError("Unknown module type") + with torch.no_grad(): + tensor_multiplier = None + if isinstance(multiplier, int) or isinstance(multiplier, float): + tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype) + elif isinstance(multiplier, list): + tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype) + elif isinstance(multiplier, torch.Tensor): + tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype) + + self.torch_multiplier = tensor_multiplier.clone().detach() + + @property + def multiplier(self) -> Union[float, List[float], List[List[float]]]: + return self._multiplier + + @multiplier.setter + def multiplier(self, value: Union[float, List[float], List[List[float]]]): + # it takes time to update all the multipliers, so we only do it if the value has changed + if self._multiplier == value: + return + # if we are setting a single value but have a list, keep the list if every item is the same as value + self._multiplier = value + self._update_torch_multiplier() + + # called when the context manager is 
def force_to(self: "Network", device, dtype):
    """Move the network and each of its LoRA modules to ``device`` / ``dtype``.

    The network's own ``to`` is called first, then ``to`` on every module
    returned by ``get_all_modules()``. Reusing that helper keeps the set of
    modules in one place instead of repeating the attribute checks here.
    """
    self.to(device, dtype)
    for lora in self.get_all_modules():
        lora.to(device, dtype)
def _adaptive_base_lr(learning_rate):
    """Return the lr handed to self-tuning optimizers: values below 0.1 become 1.0."""
    return 1.0 if learning_rate < 0.1 else learning_rate


def get_optimizer(
        params,
        optimizer_type='adam',
        learning_rate=1e-6,
        optimizer_params=None
):
    """Build an optimizer from a type name.

    Args:
        params: parameters (or parameter groups) to optimize.
        optimizer_type: case-insensitive optimizer name, e.g. ``adam``,
            ``adamw8bit``, ``prodigy``, ``dadaptationadam``, ``adafactor``.
        learning_rate: learning rate. For D-Adaptation and Prodigy variants,
            values below 0.1 are replaced with 1.0.
        optimizer_params: extra keyword arguments forwarded to the optimizer
            constructor. The mapping is copied and never modified.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: for an unknown optimizer type, including unknown
            ``dadaptation*`` and ``*8bit`` variants.
        ImportError: if ``lion`` is requested but ``lion_pytorch`` is missing.
    """
    # work on a copy: the adafactor branch fills in defaults and must not
    # leak them into the caller's dict
    optimizer_params = {} if optimizer_params is None else dict(optimizer_params)
    lower_type = optimizer_type.lower()

    if lower_type.startswith("dadaptation"):
        # dadaptation optimizer does not use standard learning rate. 1 is the default value
        import dadaptation
        print("Using DAdaptAdam optimizer")
        use_lr = _adaptive_base_lr(learning_rate)
        if lower_type.endswith('lion'):
            # NOTE(review): confirm DAdaptLion accepts an `eps` keyword
            optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
        elif lower_type.endswith('adam'):
            # the Adam variant must build DAdaptAdam, not DAdaptLion
            optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
        elif lower_type == 'dadaptation':
            # backwards compatibility
            optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
            # warn user that dadaptation is deprecated
            print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
        else:
            # without this, `optimizer` would be unbound at the final return
            raise ValueError(f'Unknown optimizer type {optimizer_type}')
    elif lower_type.startswith("prodigy8bit"):
        from toolkit.optimizers.prodigy_8bit import Prodigy8bit
        print("Using Prodigy optimizer")
        use_lr = _adaptive_base_lr(learning_rate)
        print(f"Using lr {use_lr}")
        optimizer = Prodigy8bit(params, lr=use_lr, eps=1e-6, **optimizer_params)
    elif lower_type.startswith("prodigy"):
        from prodigyopt import Prodigy

        print("Using Prodigy optimizer")
        use_lr = _adaptive_base_lr(learning_rate)
        print(f"Using lr {use_lr}")
        optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params)
    elif lower_type == "adam8":
        from toolkit.optimizers.adam8bit import Adam8bit

        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
    elif lower_type == "adamw8":
        from toolkit.optimizers.adam8bit import Adam8bit

        optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, decouple=True, **optimizer_params)
    elif lower_type.endswith("8bit"):
        import bitsandbytes

        if lower_type == "adam8bit":
            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
        elif lower_type == "ademamix8bit":
            return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
        elif lower_type == "adamw8bit":
            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
        elif lower_type == "lion8bit":
            return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
        else:
            raise ValueError(f'Unknown optimizer type {optimizer_type}')
    elif lower_type == 'adam':
        optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
    elif lower_type == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
    elif lower_type == 'lion':
        try:
            from lion_pytorch import Lion
        except ImportError:
            raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
        return Lion(params, lr=learning_rate, **optimizer_params)
    elif lower_type == 'adagrad':
        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
    elif lower_type == 'adafactor':
        from toolkit.optimizers.adafactor import Adafactor
        # an explicit lr is passed, so these default to off unless the caller set them
        optimizer_params.setdefault('relative_step', False)
        optimizer_params.setdefault('scale_parameter', False)
        optimizer_params.setdefault('warmup_init', False)
        optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
    elif lower_type == 'automagic':
        from toolkit.optimizers.automagic import Automagic
        optimizer = Automagic(params, lr=float(learning_rate), **optimizer_params)
    elif lower_type == 'automagic2':
        from toolkit.optimizers.automagic2 import Automagic2
        optimizer = Automagic2(params, lr=float(learning_rate), **optimizer_params)
    elif lower_type == 'automagic3':
        from toolkit.optimizers.automagic3 import Automagic3
        optimizer = Automagic3(params, lr=float(learning_rate), **optimizer_params)
    else:
        raise ValueError(f'Unknown optimizer type {optimizer_type}')
    return optimizer
toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation +from optimum.quanto import QBytesTensor +import random + + +class Adafactor(torch.optim.Optimizer): + """ + Adafactor implementation with stochastic rounding accumulation and stochastic rounding on apply. + Modified from transformers Adafactor implementation to support stochastic rounding accumulation and apply. + + AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: + https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + + Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that + this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and + `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*): + The external learning rate. 
+ eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`): + Regularization constants for square gradient and parameter scale respectively + clip_threshold (`float`, *optional*, defaults to 1.0): + Threshold of root mean square of final gradient update + decay_rate (`float`, *optional*, defaults to -0.8): + Coefficient used to compute running averages of square + beta1 (`float`, *optional*): + Coefficient used for computing running averages of gradient + weight_decay (`float`, *optional*, defaults to 0.0): + Weight decay (L2 penalty) + scale_parameter (`bool`, *optional*, defaults to `True`): + If True, learning rate is scaled by root mean square + relative_step (`bool`, *optional*, defaults to `True`): + If True, time-dependent learning rate is computed instead of external learning rate + warmup_init (`bool`, *optional*, defaults to `False`): + Time-dependent learning rate computation depends on whether warm-up initialization is being used + + This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. + + Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): + + - Training without LR warmup or clip_threshold is not recommended. 
def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
        do_paramiter_swapping=False,
        paramiter_swapping_factor=0.1,
        stochastic_accumulation=True,
        stochastic_rounding=True,
):
    """Set up Adafactor state, stochastic accumulation hooks and parameter swapping.

    Args:
        params: parameters or parameter groups to optimize.
        lr: external learning rate; must be ``None`` when ``relative_step`` is on.
        eps, clip_threshold, decay_rate, beta1, weight_decay, scale_parameter,
            relative_step, warmup_init: stored as the param-group defaults.
        do_paramiter_swapping: when true, ``enable_paramiter_swapping`` is
            called at the end of construction.
        paramiter_swapping_factor: factor passed to ``enable_paramiter_swapping``.
        stochastic_accumulation: register ``stochastic_grad_accummulation`` as
            a post-accumulate-grad hook on trainable non-float32 parameters.
        stochastic_rounding: stored on the instance as ``stochastic_rounding``.

    Raises:
        ValueError: for ``lr`` together with ``relative_step=True``, or
            ``warmup_init=True`` without ``relative_step=True``.
    """
    self.stochastic_rounding = stochastic_rounding
    if lr is not None and relative_step:
        raise ValueError(
            "Cannot combine manual `lr` and `relative_step=True` options")
    if warmup_init and not relative_step:
        raise ValueError(
            "`warmup_init=True` requires `relative_step=True`")

    defaults = {
        "lr": lr,
        "eps": eps,
        "clip_threshold": clip_threshold,
        "decay_rate": decay_rate,
        "beta1": beta1,
        "weight_decay": weight_decay,
        "scale_parameter": scale_parameter,
        "relative_step": relative_step,
        "warmup_init": warmup_init,
    }
    super().__init__(params, defaults)

    # read each group's own lr so per-group overrides are reported correctly
    # by get_learning_rates() before the first step; groups without an
    # explicit lr carry the constructor default
    self.base_lrs: List[float] = [
        group["lr"] for group in self.param_groups
    ]

    self.is_stochastic_rounding_accumulation = False

    # setup stochastic grad accum hooks
    if stochastic_accumulation:
        for group in self.param_groups:
            for param in group['params']:
                if param.requires_grad and param.dtype != torch.float32:
                    self.is_stochastic_rounding_accumulation = True
                    param.register_post_accumulate_grad_hook(
                        stochastic_grad_accummulation
                    )

    self.do_paramiter_swapping = do_paramiter_swapping
    self.paramiter_swapping_factor = paramiter_swapping_factor

    # count total parameters; swap_paramiters() sizes its target from this
    self._total_paramiter_size = sum(
        torch.numel(param)
        for group in self.param_groups
        for param in group['params']
    )
    # pretty print total parameters with comma separation
    print(f"Total training paramiters: {self._total_paramiter_size:,}")

    # must come after the count above
    if self.do_paramiter_swapping:
        self.enable_paramiter_swapping(self.paramiter_swapping_factor)
@torch.no_grad()
def step(self, closure=None):
    """
    Performs a single optimization step

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss. It is invoked with gradients enabled,
            because this method itself runs under ``torch.no_grad()``.

    Returns:
        The value returned by ``closure``, or ``None`` when no closure is given.

    Raises:
        RuntimeError: if a parameter carries a sparse gradient.
    """
    # Swap in any stochastically accumulated gradients before reading p.grad.
    self.step_hook()
    loss = None
    if closure is not None:
        # The decorator disables autograd; a closure that calls backward()
        # needs it switched back on.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None or not p.requires_grad:
                continue

            grad = p.grad
            if grad.dtype != torch.float32:
                grad = grad.to(torch.float32)
            if grad.is_sparse:
                raise RuntimeError(
                    "Adafactor does not support sparse gradients.")

            # if p has atts _scale then it is quantized. We need to divide the grad by the scale
            # if hasattr(p, "_scale"):
            #     grad = grad / p._scale

            state = self.state[p]
            grad_shape = grad.shape

            factored, use_first_moment = self._get_options(
                group, grad_shape)
            # State Initialization
            if len(state) == 0:
                state["step"] = 0

                if use_first_moment:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(grad)
                if factored:
                    state["exp_avg_sq_row"] = torch.zeros(
                        grad_shape[:-1]).to(grad)
                    state["exp_avg_sq_col"] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:]).to(grad)
                else:
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                state["RMS"] = 0
            else:
                # Existing state may live on another device/dtype (e.g. after
                # a checkpoint load); align it with the fp32 gradient.
                if use_first_moment:
                    state["exp_avg"] = state["exp_avg"].to(grad)
                if factored:
                    state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
                        grad)
                    state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(
                        grad)
                else:
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

            p_data_fp32 = p

            if isinstance(p_data_fp32, QBytesTensor):
                p_data_fp32 = p_data_fp32.dequantize()
            if p.dtype != torch.float32:
                # Work on an fp32 copy; it is written back to p at the end.
                p_data_fp32 = p_data_fp32.clone().float()

            state["step"] += 1
            state["RMS"] = self._rms(p_data_fp32)
            lr = self._get_lr(group, state)

            beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
            eps = group["eps"]
            if isinstance(eps, tuple) or isinstance(eps, list):
                eps = eps[0]
            update = (grad**2) + eps
            if factored:
                exp_avg_sq_row = state["exp_avg_sq_row"]
                exp_avg_sq_col = state["exp_avg_sq_col"]

                exp_avg_sq_row.mul_(beta2t).add_(
                    update.mean(dim=-1), alpha=(1.0 - beta2t))
                exp_avg_sq_col.mul_(beta2t).add_(
                    update.mean(dim=-2), alpha=(1.0 - beta2t))

                # Approximation of exponential moving average of square of gradient
                update = self._approx_sq_grad(
                    exp_avg_sq_row, exp_avg_sq_col)
                update.mul_(grad)
            else:
                exp_avg_sq = state["exp_avg_sq"]

                exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                update = exp_avg_sq.rsqrt().mul_(grad)

            # Scale the update down when its RMS exceeds clip_threshold.
            update.div_(
                (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
            update.mul_(lr)

            if use_first_moment:
                exp_avg = state["exp_avg"]
                exp_avg.mul_(group["beta1"]).add_(
                    update, alpha=(1 - group["beta1"]))
                update = exp_avg

            if group["weight_decay"] != 0:
                p_data_fp32.add_(
                    p_data_fp32, alpha=(-group["weight_decay"] * lr))

            p_data_fp32.add_(-update)

            if p.dtype != torch.float32:
                # p_data_fp32 is a detached fp32 copy here, so the result has
                # to be copied back into p or the whole update is discarded.
                if self.stochastic_rounding:
                    # apply stochastic rounding
                    copy_stochastic(p, p_data_fp32)
                else:
                    # plain write-back using the dtype's default rounding
                    p.copy_(p_data_fp32)

    return loss
def step_hook(self):
    """Promote stochastically accumulated gradients to ``.grad`` before a step.

    Does nothing unless ``is_stochastic_rounding_accumulation`` is set. For
    every parameter that requires grad and carries an ``_accum_grad``
    attribute, that tensor becomes the parameter's ``grad`` and the
    attribute is removed.
    """
    if self.is_stochastic_rounding_accumulation:
        for param_group in self.param_groups:
            # Only trainable tensors are inspected, as in the accumulation hook setup.
            trainable = (t for t in param_group['params'] if t.requires_grad)
            for tensor in trainable:
                if not hasattr(tensor, "_accum_grad"):
                    continue
                tensor.grad = tensor._accum_grad
                del tensor._accum_grad
def state_dict(self):
    """Returns the state of the optimizer as a dict.

    ``Auto8bitTensor`` entries are replaced by plain, serialisable
    ``{'_type': 'Auto8bitTensor', 'state': ...}`` records.

    The per-parameter dicts handed back by the parent ``state_dict()`` are
    the optimizer's live state objects, so they are never edited in place:
    doing so would leave dicts where ``step()`` expects ``Auto8bitTensor``
    instances and break training after the first checkpoint save. A fresh
    dict is built for every parameter instead.
    """
    state_dict = super().state_dict()

    serialisable_state = {}
    for param_id, param_state in state_dict['state'].items():
        converted = {}
        for key, value in param_state.items():
            if isinstance(value, Auto8bitTensor):
                converted[key] = {
                    '_type': 'Auto8bitTensor',
                    'state': value.state_dict()
                }
            else:
                converted[key] = value
        serialisable_state[param_id] = converted

    # Rebind only the returned mapping; self.state is left untouched.
    state_dict['state'] = serialisable_state
    return state_dict
this does not work like prodigy") + self.lr = 1e-6 + self.min_lr = min_lr + self.max_lr = max_lr + self.lr_bump = lr_bump + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "beta2": beta2, + "weight_decay": weight_decay, + } + super().__init__(params, defaults) + + self.base_lrs: List[float] = [ + lr for group in self.param_groups + ] + + self.is_stochastic_rounding_accumulation = False + + # setup stochastic grad accum hooks + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and param.dtype != torch.float32: + self.is_stochastic_rounding_accumulation = True + param.register_post_accumulate_grad_hook( + stochastic_grad_accummulation + ) + + self.do_paramiter_swapping = do_paramiter_swapping + self.paramiter_swapping_factor = paramiter_swapping_factor + self._total_paramiter_size = 0 + # count total paramiters + for group in self.param_groups: + for param in group['params']: + self._total_paramiter_size += torch.numel(param) + # pretty print total paramiters with comma seperation + print(f"Total training paramiters: {self._total_paramiter_size:,}") + + # needs to be enabled to count paramiters + if self.do_paramiter_swapping: + self.enable_paramiter_swapping(self.paramiter_swapping_factor) + + def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1): + self.do_paramiter_swapping = True + self.paramiter_swapping_factor = paramiter_swapping_factor + # call it an initial time + self.swap_paramiters() + + def swap_paramiters(self): + all_params = [] + # deactivate all paramiters + for group in self.param_groups: + for param in group['params']: + param.requires_grad_(False) + # remove any grad + param.grad = None + all_params.append(param) + # shuffle all paramiters + random.shuffle(all_params) + + # keep activating paramiters until we are going to go over the target paramiters + target_paramiters = int( + self._total_paramiter_size * self.paramiter_swapping_factor) + total_paramiters = 0 + 
for param in all_params: + total_paramiters += torch.numel(param) + if total_paramiters >= target_paramiters: + break + else: + param.requires_grad_(True) + + @staticmethod + def _get_lr(param_group, param_state): + if 'avg_lr' in param_state: + lr = param_state["avg_lr"] + else: + lr = 0.0 + return lr + + def _get_group_lr(self, group): + group_lrs = [] + for p in group["params"]: + group_lrs.append(self._get_lr(group, self.state[p])) + # return avg + if len(group_lrs) == 0: + return self.lr + return sum(group_lrs) / len(group_lrs) + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=- + 1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step_hook(self): + if not self.is_stochastic_rounding_accumulation: + return + # copy over stochastically rounded grads + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and hasattr(param, "_accum_grad"): + param.grad = param._accum_grad + del param._accum_grad + + # automagic manages its own lr + def get_learning_rates(self): + + lrs = [ + self._get_group_lr(group) + for group in self.param_groups + ] + if len(lrs) == 0: + lrs = self.base_lrs # if called before stepping + return lrs + + def get_avg_learning_rate(self): + lrs = self.get_learning_rates() + return sum(lrs) / len(lrs) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self.step_hook() + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None or not p.requires_grad: + continue + + grad = p.grad + if grad.dtype != torch.float32: + grad = grad.to(torch.float32) + if grad.is_sparse: + raise RuntimeError( + "Automagic does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored = len(grad_shape) >= 2 + # State Initialization + if len(state) == 0: + self.initialize_state(p) + else: + # Check if exp_avg_sq_row and exp_avg_sq_col exist for factored case + if factored: + if "exp_avg_sq_row" not in state or "exp_avg_sq_col" not in state: + state["exp_avg_sq_row"] = torch.zeros(p.shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(p.shape[:-2] + p.shape[-1:]).to(grad) + else: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + # Check if exp_avg_sq exists for non-factored case + else: + if "exp_avg_sq" not in state: + state["exp_avg_sq"] = torch.zeros_like(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + + if isinstance(p_data_fp32, QBytesTensor): + p_data_fp32 = p_data_fp32.dequantize() + if p.dtype != torch.float32: + p_data_fp32 = p_data_fp32.clone().float() + + # Initialize step if it doesn't exist + if "step" not in state: + state["step"] = 0 + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + + # Use fixed beta2 from group instead of decay_rate calculation + beta2 = group["beta2"] + eps = group["eps"] + if isinstance(eps, tuple) or isinstance(eps, list): + eps = eps[0] + update = (grad**2) + eps + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2).add_( + update.mean(dim=-1), alpha=(1.0 - beta2)) + exp_avg_sq_col.mul_(beta2).add_( + update.mean(dim=-2), alpha=(1.0 - beta2)) + + # Approximation 
of exponential moving average of square of gradient + update = self._approx_sq_grad( + exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2).add_(update, alpha=(1.0 - beta2)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_( + (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + + # Ensure state is properly initialized + if 'last_polarity' not in state or 'lr_mask' not in state: + self.initialize_state(p) + + # Get signs of current last update and updates + last_polarity = state['last_polarity'] + current_polarity = (update > 0).to(torch.bool) + sign_agreement = torch.where( + last_polarity == current_polarity, 1, -1) + state['last_polarity'] = current_polarity + + lr_mask = state['lr_mask'].to(torch.float32) + + # Update learning rate mask based on sign agreement + new_lr = torch.where( + sign_agreement > 0, + lr_mask + self.lr_bump, # Increase lr + lr_mask - self.lr_bump # Decrease lr + ) + + # Clip learning rates to bounds + new_lr = torch.clamp( + new_lr, + min=self.min_lr, + max=self.max_lr + ) + + # Apply the learning rate mask to the update + update.mul_(new_lr) + + state['lr_mask'] = Auto8bitTensor(new_lr) + state['avg_lr'] = torch.mean(new_lr) + + if group["weight_decay"] != 0: + # Apply weight decay with per-parameter learning rates + # Instead of using add_ with a tensor alpha (which isn't supported), + # we'll use element-wise multiplication to apply the weight decay + weight_decay_update = p_data_fp32 * (-group["weight_decay"]) * new_lr + p_data_fp32.add_(weight_decay_update) + + p_data_fp32.add_(-update) + + if p.dtype != torch.float32: + # apply stochastic rounding + copy_stochastic(p, p_data_fp32) + + return loss + + def initialize_state(self, p): + state = self.state[p] + state["step"] = 0 + + # store the lr mask + if 'lr_mask' not in state: + state['lr_mask'] = Auto8bitTensor(torch.ones( + p.shape).to(p.device, dtype=torch.float32) * self.lr + ) + 
def load_state_dict(self, state_dict, strict=True):
    """Load an Automagic checkpoint, restoring the per-parameter ``lr_mask``.

    Arguments:
        state_dict (dict): optimizer state as produced by ``state_dict()``.
        strict (bool): accepted for signature compatibility; not used.

    A checkpoint with no ``lr_mask`` in any state entry is treated as coming
    from a different optimizer: a warning is printed and nothing is loaded.

    Otherwise everything except ``lr_mask`` goes through the parent loader,
    and each mask is then rebuilt as an ``Auto8bitTensor``. Saved entries are
    matched to current parameters through the ids listed in the saved
    ``param_groups`` (the same correspondence the parent loader uses), not by
    position among trainable parameters, so frozen parameters or parameters
    that never had state cannot shift the masks onto the wrong tensors.
    A mask whose shape does not match, or that fails to load, is replaced
    by a fresh mask filled with ``self.lr``.
    """
    saved_states = state_dict['state'] if 'state' in state_dict else None

    # An lr_mask in at least one entry is what identifies Automagic state.
    is_valid_automagic_state = isinstance(saved_states, dict) and any(
        isinstance(param_state, dict) and 'lr_mask' in param_state
        for param_state in saved_states.values()
    )
    if not is_valid_automagic_state:
        print("WARNING: state dict does not contain Automagic lr_mask state. "
              "Optimizer state was not loaded.")
        return

    # The parent loader handles everything except the quantized lr_mask.
    state_dict_copy = {
        'state': {
            param_id: {k: v for k, v in param_state.items() if k != 'lr_mask'}
            for param_id, param_state in saved_states.items()
        },
        'param_groups': state_dict['param_groups']
    }
    super().load_state_dict(state_dict_copy)

    def new_lr_mask(param):
        # Same initial mask that initialize_state builds.
        return Auto8bitTensor(torch.ones(
            param.shape).to(param.device, dtype=torch.float32) * self.lr
        )

    # Saved ids and current parameters line up group by group; the parent
    # loader has already rejected any size mismatch.
    saved_param_ids = [
        param_id
        for saved_group in state_dict['param_groups']
        for param_id in saved_group['params']
    ]
    current_params = [
        p for group in self.param_groups for p in group['params']
    ]

    for i, (saved_param_id, current_param) in enumerate(
            zip(saved_param_ids, current_params)):
        saved_state = saved_states.get(saved_param_id)

        # Skip if this parameter has no saved state or no lr_mask
        if not isinstance(saved_state, dict) or 'lr_mask' not in saved_state:
            continue

        # Initialize the state for this parameter if it doesn't exist
        if current_param not in self.state:
            self.initialize_state(current_param)

        current_state = self.state[current_param]
        saved_lr_mask = saved_state['lr_mask']

        # Reconstruct the Auto8bitTensor from its state dict
        try:
            # Make sure the shapes match
            if 'quantized' in saved_lr_mask and saved_lr_mask['quantized'].shape == current_param.shape:
                current_state['lr_mask'] = Auto8bitTensor(saved_lr_mask)
            else:
                print(f"WARNING: Shape mismatch for parameter {i}. "
                      f"Expected {current_param.shape}, got {saved_lr_mask['quantized'].shape if 'quantized' in saved_lr_mask else 'unknown'}. "
                      f"Initializing new lr_mask.")
                current_state['lr_mask'] = new_lr_mask(current_param)
        except Exception as e:
            print(f"ERROR: Failed to load lr_mask for parameter {i}: {e}")
            current_state['lr_mask'] = new_lr_mask(current_param)
+ """ + + def __init__( + self, + params, + lr: float = 1e-6, + min_lr: float = 1e-7, + max_lr: float = 1e-3, + lr_bump: float = 1e-6, + beta2: float = 0.999, + eps: float = 1e-30, + clip_threshold: float = 1.0, + weight_decay: float = 0.0, + agreement_threshold: float = 0.5, + ): + if lr > 1e-3: + print(f"Warning! Start lr {lr} is very high; forcing to 1e-6.") + lr = 1e-6 + defaults = dict( + lr=lr, + min_lr=min_lr, + max_lr=max_lr, + lr_bump=lr_bump, + beta2=beta2, + eps=eps, + clip_threshold=clip_threshold, + weight_decay=weight_decay, + agreement_threshold=agreement_threshold, + ) + super().__init__(params, defaults) + + self._hook_handles = [] + for group in self.param_groups: + for p in group["params"]: + if p.requires_grad: + handle = p.register_post_accumulate_grad_hook( + self._make_backward_hook(group) + ) + self._hook_handles.append(handle) + + total = sum(p.numel() for g in self.param_groups for p in g["params"]) + print(f"Total training paramiters: {total:,}") + + # ------------------------------------------------------------------ utils + + @staticmethod + def _rms(t: torch.Tensor) -> torch.Tensor: + return t.norm(2) / (t.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(row: torch.Tensor, col: torch.Tensor) -> torch.Tensor: + r = (row / row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c = col.unsqueeze(-2).rsqrt() + return torch.mul(r, c) + + def _init_state(self, p: torch.Tensor, group: dict) -> None: + state = self.state[p] + state["step"] = 0 + state["lr"] = torch.full( + (), float(group["lr"]), dtype=torch.float32, device=p.device + ) + state["last_polarity"] = torch.zeros(p.shape, dtype=torch.bool, device=p.device) + if p.dim() >= 2: + state["exp_avg_sq_row"] = torch.zeros( + p.shape[:-1], dtype=p.dtype, device=p.device + ) + state["exp_avg_sq_col"] = torch.zeros( + p.shape[:-2] + p.shape[-1:], dtype=p.dtype, device=p.device + ) + else: + state["exp_avg_sq"] = torch.zeros(p.shape, dtype=p.dtype, device=p.device) + + def 
@torch.no_grad()
def _update_param(self, p: torch.Tensor, group: dict) -> None:
    """Apply one optimizer update to ``p`` and clear its gradient.

    Called by the hook returned from ``_make_backward_hook`` with the
    parameter and the param group captured when the hook was created.

    Arguments:
        p: parameter to update; a no-op when ``p.grad`` is ``None``.
        group: hyperparameter dict read for ``beta2``, ``eps``,
            ``clip_threshold``, ``agreement_threshold``, ``lr_bump``,
            ``min_lr``, ``max_lr`` and ``weight_decay``.

    Raises:
        RuntimeError: if the gradient is sparse.
    """
    if p.grad is None:
        return
    state = self.state[p]
    if len(state) == 0:
        # First time this parameter is seen: create its state.
        self._init_state(p, group)

    grad = p.grad
    if grad.is_sparse:
        raise RuntimeError("Automagic2 does not support sparse gradients.")
    if grad.dtype != torch.float32:
        # All math below runs in fp32.
        grad = grad.to(torch.float32)

    beta2 = group["beta2"]
    eps = group["eps"]
    sq = (grad * grad).add_(eps)

    if p.dim() >= 2:
        # Factored second moment: EMAs of the squared grad averaged over
        # the last dim (row) and the second-to-last dim (col).
        row_state = state["exp_avg_sq_row"]
        col_state = state["exp_avg_sq_col"]
        if row_state.dtype == torch.float32:
            # fp32 state is updated in place.
            row, col = row_state, col_state
            row.mul_(beta2).add_(sq.mean(dim=-1), alpha=1.0 - beta2)
            col.mul_(beta2).add_(sq.mean(dim=-2), alpha=1.0 - beta2)
        else:
            # Lower-precision state: update an fp32 copy, then write it
            # back in the stored dtype. The fp32 copy is what feeds the
            # update below.
            row = row_state.to(torch.float32)
            col = col_state.to(torch.float32)
            row.mul_(beta2).add_(sq.mean(dim=-1), alpha=1.0 - beta2)
            col.mul_(beta2).add_(sq.mean(dim=-2), alpha=1.0 - beta2)
            row_state.copy_(row.to(row_state.dtype))
            col_state.copy_(col.to(col_state.dtype))
        update = self._approx_sq_grad(row, col).mul_(grad)
    else:
        # Fewer than two dims: one full-shape second-moment EMA.
        v_state = state["exp_avg_sq"]
        if v_state.dtype == torch.float32:
            v = v_state
            v.mul_(beta2).add_(sq, alpha=1.0 - beta2)
        else:
            v = v_state.to(torch.float32)
            v.mul_(beta2).add_(sq, alpha=1.0 - beta2)
            v_state.copy_(v.to(v_state.dtype))
        update = v.rsqrt().mul_(grad)

    # Scale the update down when its RMS exceeds clip_threshold.
    update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))

    # Per-element sign agreement collapsed to a single bump decision.
    # Kept on-device as a 0-D tensor to avoid a CPU<->GPU sync in the hot path.
    cur_polarity = update > 0
    last_polarity = state["last_polarity"]
    agreement = (cur_polarity == last_polarity).to(torch.float32).mean()
    state["last_polarity"] = cur_polarity

    lr_t = state["lr"]
    if state["step"] > 0:
        # Skipped on the very first update, when last_polarity is still
        # the all-False initial mask. direction is +1 when agreement
        # reaches the threshold and -1 otherwise.
        direction = (agreement >= group["agreement_threshold"]).to(lr_t.dtype) * 2.0 - 1.0
        lr_t.add_(direction, alpha=group["lr_bump"]).clamp_(
            min=group["min_lr"], max=group["max_lr"]
        )
    state["step"] += 1

    update.mul_(lr_t)
    wd = group["weight_decay"]

    if p.dtype == torch.bfloat16:
        # Single bf16 -> fp32 conversion shared by weight decay and SR.
        new_p_fp32 = p.to(torch.float32)
        if wd != 0.0:
            # Fold weight decay into the update: update += wd * p * lr.
            update.addcmul_(new_p_fp32, lr_t, value=wd)
        new_p_fp32.sub_(update)
        # Stochastic rounding fp32 -> bf16: add random noise into the lower
        # 16 mantissa bits, then truncate. Done in place on new_p_fp32 so
        # we don't allocate a separate int32 work buffer.
        as_int = new_p_fp32.view(torch.int32)
        as_int.add_(torch.randint_like(as_int, 1 << 16)).bitwise_and_(-65536)
        p.copy_(new_p_fp32)
    else:
        if wd != 0.0:
            p_fp32 = p if p.dtype == torch.float32 else p.to(torch.float32)
            update.addcmul_(p_fp32, lr_t, value=wd)
        # Direct write-back in p's dtype; no stochastic rounding here.
        p.add_(update.to(p.dtype), alpha=-1.0)

    # Free the gradient now that it has been consumed.
    p.grad = None
"""
NOTE: This is experimental and under active development; expect breaking changes and bugs. Feedback welcome.
"""

from typing import List
import torch


class Automagic3(torch.optim.Optimizer):
    """
    Automagic v3.

    A single learning rate is kept per param group (typically: one lr for
    the whole run). The control principle: the lr RISES while elements hold
    a decisively consistent update direction at the current step size, FALLS
    while their signs decisively alternate (the overshoot signature: weights
    hopping across a minimum flip sign step to step -- shrinking the step is
    what makes a trajectory reappear at a finer scale), and HOLDS on
    everything in between, which is treated as noise.

    Each element keeps a window of its last H (= ``polarity_history``,
    default 8) update sign bits ("is the update positive", 1-bit packed) --
    H/8 bytes per element (one byte at the default), the only
    per-element optimizer state. A short window suffices because verdicts
    are pooled across the whole group: millions of voters make weak
    common-mode evidence visible long before any single element is
    decisive, and the window length is also the controller's reaction lag
    and warmup.

    Vote rule (per element)
    -----------------------
    Only the two perfectly decisive window states vote; everything else is
    noise:

        up     all H signs agree           +1 * |update|   ("step too small")
        down   all H-1 transitions flip    -1 * |update|   ("step too large":
               (perfect alternation)                        the overshoot signature)
        else   any imperfect window         0              (noise)

    The two events are exact mirrors with IDENTICAL pure-noise probability
    (2 of the 2^H possible windows each; ~0.8% per element at H=8), so equal
    weights balance exactly -- no correction factors, no tiers. Per element
    the events are rare, but the verdict is pooled over the whole group
    (millions of elements -> tens of thousands of voters per step even
    under pure noise, mean zero), so the pooled signal is smooth and a real
    trend or real overshoot shifts it decisively. A majority being overshot
    always outvotes a persistent minority, which is what anchors the lr's
    absolute level without external rails. Weighting by |update| lets the
    elements actually moving the weights dominate; an exact-zero update
    records as the negative bit, but such dead/masked elements carry zero
    weight anyway. A tensor abstains entirely until its window has filled
    (the first H steps, and again after a history reset on resume).

    ONE learning rate per param group -- not per tensor. Every element of
    every tensor in the group votes into a single pool, and the group lr is
    nudged once per step by the pooled result, applied multiplicatively
    with NO gain knob: ``lr *= exp(vote)`` -- the lr moves at exactly the
    rate the model votes for it. A fully unanimous pool (practically
    unreachable) would move e ~= 2.7x per step; the silent majority dilutes
    the pooled vote, so realistic moves are a few percent per step, and the
    worst-case overshoot past the edge is bounded by the H-step window lag
    before alternation votes answer. There is no
    noise-floor estimation, no smoothing, no significance test: the polarity
    windows are the only indicator. Pooling at group level (rather than per
    tensor, and originally rather than per channel) is the load-bearing
    choice: COUPLED tensors fight per-tensor lrs exactly like coupled
    channels fight per-channel ones. A Q/K pair is the canonical case --
    Q's weights scaling up while K's scale down preserves the attention
    logits, so the gradients reward whichever asymmetry randomly seeded
    first: Q votes "too slow" and climbs while K votes "too fast" and sinks,
    self-reinforcing without bound. One shared lr makes those opposing votes
    cancel in the pool instead of diverging, so only common-mode evidence
    ("the whole group's step is too small / too large") moves the lr.

    With ``fused=True`` (default) the step is fused into the backward pass via
    ``register_post_accumulate_grad_hook``: each parameter is updated and its
    grad freed as soon as autograd finishes accumulating into it. ``.step()``
    therefore only applies the pooled lr vote and peak VRAM stays low. Note
    this bypasses the trainer's grad clipping / nan-skip (they run after
    backward) and is not compatible with multi-backward gradient accumulation.

    With ``fused=False`` it behaves like a traditional optimizer: grads
    accumulate across backward passes and the update happens in ``.step()``.
    Low-precision (bf16/fp16) grads are accumulated with stochastic rounding so
    small per-micro-batch grads aren't lost; fp32 grads accumulate normally.

    Second-moment EMA state is stored in ``p.dtype`` (math runs in fp32 when
    the state is lower precision). Updates to low-precision (e.g. bf16/fp16)
    parameters are applied in fp32 and stochastically rounded on write-back.

    Parameters
    ----------
    lr : float
        Starting learning rate for every group. The controller adapts away
        from this in whichever direction the pooled vote points, so it is a
        launch point, not a tuned target. There are no min/max lr clamps
        (only a numerical overflow guard far outside the usable range).
    beta2 : float
        EMA decay for the second moment, as in Adam/Adafactor.
    eps : float
        Floor added to the second moment before the rsqrt, to avoid div-by-zero.
    clip_threshold : float
        Trust region on the update: its RMS is scaled to <= this, then every
        element is clamped to +/- this, so no single weight takes an outsized
        step.
    weight_decay : float
        Decoupled (AdamW-style) weight decay; 0 disables it.
    polarity_history : int
        Sign-history window length H (clamped to 2..64, default 8); H/8 bytes
        of state per element. Longer windows make the two vote events rarer
        and more decisive (probability 2^(1-H) each under noise -- a real
        trend's excess grows ~(1+rho)^H), so detection sharpens, at the
        cost of memory, an H-step reaction lag/warmup, and fewer voters
        per step. Changing it on resume resets the histories cleanly (one
        re-warmup of H steps).
    fused : bool
        If True (default), each param is updated inside the backward pass the
        moment its grad is ready -- low peak VRAM, but it bypasses the trainer's
        grad clipping / nan-skip and cannot be combined with multi-backward
        gradient accumulation. If False, a normal ``.step()``-time update, with
        low-precision grads accumulated using stochastic rounding.

    Improvements over v2
    --------------------
    1. One adaptive lr per param group (v2 had one static lr per tensor).
       Plain English: the group finds its learning rate automatically, and no
       layer can run away or freeze relative to the others -- which is what
       used to split a full finetune into over-cooked and dead layers and
       destroy it. (Earlier v3s used a separate lr per output channel, then
       per tensor; each level let coupled units -- channels, then Q/K-style
       tensor pairs -- fight and split to opposite extremes, so the lr was
       pooled one level up each time until the fighting was structurally
       impossible.)

    2. Direction-consistency lr control with a real equilibrium. v2 bumped
       the lr from raw single-step agreement, which has no upper fixed point
       -- a parameter that is simply still descending keeps agreeing at any
       lr, so the lr ratchets up and eventually runs away on long runs. v3
       votes from each element's recent sign window (see the vote rule
       above). Plain English: the lr speeds up while the model holds a
       trajectory, backs off hard when it overshoots, and holds steady on
       pure noise.

    3. Multiplicative (geometric) lr bump (was additive). v2 added/subtracted a
       fixed absolute amount, so the same bump was a huge relative jump when the
       lr was tiny and a negligible one when it was large. v3 multiplies by
       ``exp(vote)`` -- a fixed *percentage* step. Plain
       English: the lr moves at the same relative pace whether it is tiny or
       large, traverses its whole range in a predictable number of steps, and a
       full up bump is exactly cancelled by a full down bump (no drift); the
       gain knob was removed entirely once the vote became a pooled
       fraction with natural log-units.

    4. Stochastic rounding for fp16, not just bf16. v2 only rounded bf16
       write-backs and let fp16 fall back to round-to-nearest, silently
       discarding updates smaller than an fp16 ULP. v3 stochastically rounds
       both (fast bit-trick for bf16/fp16, generic fallback for other low
       precisions). Plain English: fp16 training no longer throws away small
       weight updates, so it actually keeps learning instead of stalling.

    5. Faster hot path, identical math. eps is folded into the small reduced
       row/col vectors instead of the full gradient-square tensor; the lr scale
       and parameter update are fused into one ``addcmul_``; the per-element
       direction and flip sums are recomputed from the 1-bit history planes
       in a single batched unpack and two integer reductions, and scored
       with three boolean compares and weighted sums. Plain English: each
       step issues few GPU passes over the weights, so it runs fast
       (notably in bf16/fp16).
    """

    def __init__(
        self,
        params,
        lr: float = 1e-6,
        beta2: float = 0.999,
        eps: float = 1e-30,
        clip_threshold: float = 1.0,
        weight_decay: float = 0.0,
        polarity_history: int = 8,  # sign-history window length (2-64)
        fused: bool = True,
    ):
        if lr > 1e-3:
            # No clamping: a too-high start just oscillates immediately and
            # the controller drives it down.
            print(
                f"Note: start lr {lr} is high; the controller will correct it "
                f"(the pooled vote will walk it down)."
            )
        defaults = dict(
            lr=lr,
            beta2=beta2,
            eps=eps,
            clip_threshold=clip_threshold,
            weight_decay=weight_decay,
            # Out-of-range window lengths are clamped rather than rejected.
            polarity_history=max(2, min(64, int(polarity_history))),
        )
        super().__init__(params, defaults)

        self.fused = fused
        self._rebuild_group_index()
        self._hook_handles = []
        for group in self.param_groups:
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                if self.fused:
                    # Fused: update each param the moment its grad is ready.
                    handle = p.register_post_accumulate_grad_hook(
                        self._make_backward_hook(group)
                    )
                    self._hook_handles.append(handle)
                elif p.dtype != torch.float32:
                    # Non-fused: the actual update happens in .step(); here we
                    # only stochastically accumulate low-precision grads across
                    # micro-batches so repeated round-to-nearest doesn't drop
                    # small grads (fp32 grads accumulate losslessly on their own).
                    handle = p.register_post_accumulate_grad_hook(
                        self._make_accum_hook()
                    )
                    self._hook_handles.append(handle)

        total = sum(p.numel() for g in self.param_groups for p in g["params"])
        print(f"Total training parameters: {total:,}")

    # ------------------------------------------------------------------ utils

    @staticmethod
    def _rms(t: torch.Tensor) -> torch.Tensor:
        # Root-mean-square of a tensor; used to size the trust-region clip.
        return t.norm(2) / (t.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(row: torch.Tensor, col: torch.Tensor) -> torch.Tensor:
        # Adafactor's factored second moment (inherited from v2). Rather than
        # store a full RxC tensor of running grad^2, only its per-row and
        # per-col means are kept; this rebuilds the rank-1 approximation of
        # 1/sqrt(v) -- the per-element update scale -- as the outer product
        # rsqrt(row / mean(row)) (x) rsqrt(col). That is the standard HF
        # Adafactor reconstruction and is what keeps optimizer state small.
        r = (row / row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c = col.unsqueeze(-2).rsqrt()
        return torch.mul(r, c)

    @staticmethod
    def _sr_truncate(v_fp32: torch.Tensor, drop_bits: int) -> torch.Tensor:
        # Fast in-place stochastic rounding for a low-precision float that is a
        # mantissa truncation of fp32: add uniform noise into the dropped low
        # mantissa bits of the fp32 bit pattern, then zero them, so the
        # subsequent narrowing cast is exact and rounds up with probability
        # equal to the truncated fractional part. bf16 drops 16 bits (it is the
        # high half of fp32); fp16 drops 13 bits (23 - 10 mantissa) and is exact
        # within its normal exponent range -- values past fp16's overflow /
        # subnormal limits are rounded at fp32 granularity, which trained
        # weights effectively never reach.
        as_int = v_fp32.view(torch.int32)
        as_int.add_(torch.randint_like(as_int, 1 << drop_bits))
        as_int.bitwise_and_(-(1 << drop_bits))
        return v_fp32

    @staticmethod
    def _stochastic_round(v: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # Generic fp32 -> low-precision stochastic rounding for dtypes that are
        # not a mantissa truncation of fp32 (the bf16/fp16 fast path in
        # _sr_truncate does not apply). Adds uniform noise of +/- half a target
        # ULP and rounds to nearest, so P(round up) equals the fractional
        # distance to the next representable value -> unbiased in expectation.
        # The ULP at |v| is 2**floor(log2|v|) * eps(dtype). ``v`` is mutated.
        finfo = torch.finfo(dtype)
        absv = v.abs().clamp_(min=finfo.tiny)
        ulp = torch.exp2(torch.floor(torch.log2(absv))).mul_(finfo.eps)
        noise = torch.rand_like(v).sub_(0.5).mul_(ulp)
        return v.add_(noise).to(dtype)

    # Per-device cached constants for pack/unpack (avoid re-allocating a tiny
    # tensor on every call).
    _PACK_CONSTS: dict = {}

    @classmethod
    def _pack_consts(cls, device):
        # Returns (bit weights 1..128, shift amounts 0..7) as uint8 tensors on
        # ``device``, creating and caching them on first use.
        consts = cls._PACK_CONSTS.get(device)
        if consts is None:
            consts = (
                torch.tensor(
                    [1, 2, 4, 8, 16, 32, 64, 128], device=device, dtype=torch.uint8
                ),
                torch.tensor(
                    [0, 1, 2, 3, 4, 5, 6, 7], device=device, dtype=torch.uint8
                ),
            )
            cls._PACK_CONSTS[device] = consts
        return consts

    @classmethod
    def _pack_bits(cls, bits: torch.Tensor) -> torch.Tensor:
        # Pack sign bits (bool / {0, 1}) 8 per byte (uint8), as a base-2 dot
        # product per group of 8 (two kernels rather than per-slice shift/or
        # chains). The tail is zero-padded to a whole byte.
        weights, _ = cls._pack_consts(bits.device)
        flat = bits.reshape(-1).to(torch.uint8)
        pad = (-flat.numel()) % 8
        if pad:
            flat = torch.cat([flat, flat.new_zeros(pad)])
        return (flat.view(-1, 8) * weights).sum(-1, dtype=torch.uint8)

    @classmethod
    def _unpack_bits(cls, packed: torch.Tensor, numel: int) -> torch.Tensor:
        # Inverse of _pack_bits: uint8 -> flat uint8 {0, 1} of length numel.
        _, shifts = cls._pack_consts(packed.device)
        vals = (packed.unsqueeze(-1) >> shifts).bitwise_and_(1)
        return vals.view(-1)[:numel]

    def _rebuild_group_index(self) -> None:
        # param -> index of its param group, plus per-group vote accumulators
        # (weighted vote mass and total weight, gathered across every tensor
        # in the group during the step and applied once in .step()). The map
        # exists because the fused hooks cannot rely on group-dict identity:
        # the parent's load_state_dict replaces the group dicts.
        self._param_group_index = {
            p: gi for gi, group in enumerate(self.param_groups) for p in group["params"]
        }
        self._group_num: List = [None] * len(self.param_groups)
        self._group_den: List = [None] * len(self.param_groups)

    @classmethod
    def _stochastic_copy_(cls, dst: torch.Tensor, src_fp32: torch.Tensor) -> None:
        # Stochastically round the fp32 ``src`` into the low-precision ``dst`` in
        # place. Uses the fast mantissa-truncation path for bf16/fp16 and the
        # generic method otherwise. ``src_fp32`` may be mutated (caller owns it).
        if dst.dtype == torch.bfloat16:
            dst.copy_(cls._sr_truncate(src_fp32, 16))
        elif dst.dtype == torch.float16:
            dst.copy_(cls._sr_truncate(src_fp32, 13))
        else:
            dst.copy_(cls._stochastic_round(src_fp32, dst.dtype))

    def _make_accum_hook(self):
        # Non-fused grad accumulation for low-precision params: accumulate the
        # running sum in fp32 then stochastically round it back into the
        # low-precision ``_accum_grad`` buffer, so small per-micro-batch grads
        # are not lost to repeated round-to-nearest. .step() consumes the buffer.
        def _hook(p: torch.Tensor):
            if p.grad is None:
                return
            if hasattr(p, "_accum_grad"):
                acc = p._accum_grad.to(torch.float32).add_(p.grad.to(torch.float32))
                self._stochastic_copy_(p._accum_grad, acc)
            else:
                p._accum_grad = p.grad.clone()
            p.grad = None

        return _hook

    def _init_state(self, p: torch.Tensor, group: dict) -> None:
        state = self.state[p]
        state["step"] = 0
        # The group lr, mirrored per param (every param in a group receives
        # identical multiplicative nudges, so these stay equal; storing per
        # param rides the normal state_dict machinery and tolerates
        # multi-device groups).
        state["lr"] = torch.tensor(
            float(group["lr"]), dtype=torch.float32, device=p.device
        )
        # Ring buffer of per-element update sign bits, one 1-bit-packed
        # plane per step (H/8 bytes per element). Sums are recomputed from
        # the planes each step rather than stored -- the history is the
        # ONLY per-element state.
        H = group["polarity_history"]
        width = (p.numel() + 7) // 8
        state["sign_history"] = torch.zeros(
            (H, width), dtype=torch.uint8, device=p.device
        )
        # Index of the OLDEST plane (the one overwritten next step).
        state["hist_idx"] = 0
        # Number of real sign planes stored so far; the controller is gated
        # until the window is full (there is no per-element abstain state).
        state["hist_fill"] = 0
        if p.dim() >= 2:
            state["exp_avg_sq_row"] = torch.zeros(
                p.shape[:-1], dtype=p.dtype, device=p.device
            )
            state["exp_avg_sq_col"] = torch.zeros(
                p.shape[:-2] + p.shape[-1:], dtype=p.dtype, device=p.device
            )
        else:
            state["exp_avg_sq"] = torch.zeros(p.shape, dtype=p.dtype, device=p.device)

    def _make_backward_hook(self, group):
        # ``group`` is the group dict as it existed at construction time. The
        # parent's load_state_dict later replaces the dicts in
        # ``self.param_groups``, so the hook resolves the LIVE group through
        # the param -> group-index map on every call; otherwise fused mode
        # would keep reading stale hyperparameters while non-fused mode
        # (which iterates self.param_groups in .step()) reads the current
        # ones. The captured dict is only a fallback for unmapped params.
        def _hook(p: torch.Tensor):
            gi = self._param_group_index.get(p)
            if gi is not None and gi < len(self.param_groups):
                live_group = self.param_groups[gi]
            else:
                live_group = group
            self._update_param(p, live_group)

        return _hook

    # -------------------------------------------------------------- per-param

    @torch.no_grad()
    def _update_param(self, p: torch.Tensor, group: dict) -> None:
        # Apply one optimizer update to ``p`` from ``p.grad`` (then free the
        # grad) and accumulate this tensor's lr vote into its group's pool.
        if p.grad is None:
            return
        state = self.state[p]
        if len(state) == 0:
            self._init_state(p, group)

        grad = p.grad
        if grad.is_sparse:
            raise RuntimeError("Automagic3 does not support sparse gradients.")
        if grad.dtype != torch.float32:
            grad = grad.to(torch.float32)

        # In fused mode this runs inside backward, so the trainer's grad
        # clipping and nan/inf-skip come too late to protect us. A single
        # non-finite gradient would poison the second-moment EMA (NaN stays
        # NaN forever) and corrupt the weights, so neutralise non-finite
        # grads in place; those elements contribute nothing this step. For
        # low-precision grads this is our own fp32 copy; for fp32 grads it is
        # p.grad itself, which is released at the end of this method anyway.
        # Large but finite grads are left alone -- the second-moment
        # normalisation already bounds their effect.
        grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)

        beta2 = group["beta2"]
        eps = group["eps"]
        # eps is folded into the reduced row/col (or rsqrt) instead of being
        # added to the full-size sq tensor: mean(g^2 + eps) == mean(g^2) + eps,
        # which saves a full-size kernel pass.
        sq = grad * grad

        # Second moment: a beta2-EMA of grad^2, then update = grad / sqrt(v),
        # exactly as Adam/Adafactor (this magnitude-normalises the step; only the
        # *sign* of the result drives the lr controller further down). For >=2D
        # params v is Adafactor-factored into row/col means (small state, see
        # _approx_sq_grad); 1D params (biases, norms) keep the full per-element
        # second moment. State lives in p.dtype; when that is low precision the
        # math is done in an fp32 copy and written back (copy_ casts).
        if p.dim() >= 2:
            row_state = state["exp_avg_sq_row"]
            col_state = state["exp_avg_sq_col"]
            low_precision = row_state.dtype != torch.float32
            row = row_state.to(torch.float32) if low_precision else row_state
            col = col_state.to(torch.float32) if low_precision else col_state
            row.mul_(beta2).add_(sq.mean(dim=-1).add_(eps), alpha=1.0 - beta2)
            col.mul_(beta2).add_(sq.mean(dim=-2).add_(eps), alpha=1.0 - beta2)
            if low_precision:
                row_state.copy_(row)
                col_state.copy_(col)
            update = self._approx_sq_grad(row, col).mul_(grad)
        else:
            v_state = state["exp_avg_sq"]
            low_precision = v_state.dtype != torch.float32
            v = v_state.to(torch.float32) if low_precision else v_state
            v.mul_(beta2).add_(sq, alpha=1.0 - beta2)
            if low_precision:
                v_state.copy_(v)
            update = v.add(eps).rsqrt().mul_(grad)

        # Update-RMS clip (trust region): scale so the update RMS never exceeds
        # clip_threshold. No bias-correction warmup -- LoRA runs are short and a
        # slow ramp wastes steps; for a soft start the user can set a low start
        # lr and let the lr bump up on its own.
        update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
        # The RMS clip only bounds the aggregate, so a single outlier element can
        # still survive at ~sqrt(numel)*clip_threshold and hit one weight hard,
        # distorting the model. Cap each element to clip_threshold (a true
        # max-norm trust region) so no single weight can take an outsized step.
        update.clamp_(-group["clip_threshold"], group["clip_threshold"])

        # Direction-consistency lr control (the vote rule -- see the class
        # docstring). The second-moment scale, RMS clip and clamp are all
        # positive, so the sign bit is exactly sign(grad); an exact-zero
        # update records as the negative bit, harmless because its |update|
        # vote weight is zero.
        cur_bits = update.gt(0.0)
        hist = state["sign_history"]  # (H, ceil(numel/8)) 1-bit packed uint8
        idx = state["hist_idx"]  # oldest plane (overwritten below)
        H = hist.shape[0]
        lr_t = state["lr"]  # this param's mirror of the shared group lr

        # Slide the window first so the vote sees the freshest H signs.
        hist[idx].copy_(self._pack_bits(cur_bits))
        state["hist_idx"] = (idx + 1) % H
        # The planes hold garbage until H real signs have been stored (fresh
        # start or a history reset on resume): gate the controller, not the
        # parameter update, until the window is full.
        fill = min(H, state["hist_fill"] + 1)
        state["hist_fill"] = fill

        if fill == H:
            # Extremes-only vote (see the class docstring): all H signs
            # agreeing votes up, perfect alternation (all H-1 transitions
            # flipping) votes down -- the two events have identical
            # pure-noise probability (2 of the 2^H windows each), so equal
            # +/-1 weights balance exactly. The planes are rolled into
            # chronological order so adjacent rows are adjacent steps; XOR
            # of neighbour rows marks per-bit flips. The weighted vote mass
            # and total weight are ACCUMULATED into this tensor's group; the
            # single group lr is nudged once per step in .step().
            _, shifts = self._pack_consts(hist.device)
            chron = torch.roll(hist, -state["hist_idx"], dims=0)
            bits = (
                (chron.unsqueeze(-1) >> shifts)
                .bitwise_and_(1)
                .view(H, -1)[:, : update.numel()]
            )
            s1 = bits.sum(0, dtype=torch.int16)
            flips = (bits[1:] ^ bits[:-1]).sum(0, dtype=torch.int16)
            up = s1.eq(H).logical_or_(s1.eq(0))
            down = flips.eq(H - 1)
            w = update.abs().view(-1)
            num = (w * up).sum().sub_((w * down).sum())
            den = w.sum()
            gi = self._param_group_index.get(p)
            if gi is not None:
                if self._group_num[gi] is None:
                    self._group_num[gi] = num
                    self._group_den[gi] = den
                else:
                    acc = self._group_num[gi]
                    if num.device != acc.device:
                        # Multi-device group: pool on the first voter's device.
                        num = num.to(acc.device)
                        den = den.to(acc.device)
                    acc.add_(num)
                    self._group_den[gi].add_(den)

        state["step"] += 1

        wd = group["weight_decay"]

        if p.dtype == torch.float32:
            # Decoupled weight decay folded in (update += wd*p), then a single
            # fused p -= lr * update (lr is a scalar, broadcasts).
            if wd != 0.0:
                update.add_(p, alpha=wd)
            p.addcmul_(update, lr_t, value=-1.0)
        else:
            # Low precision: apply the update in fp32 then stochastically round
            # back, so tiny updates aren't lost to round-to-nearest. Single
            # bf16/fp16 -> fp32 conversion shared by weight decay and rounding.
            new_p_fp32 = p.to(torch.float32)
            if wd != 0.0:
                update.add_(new_p_fp32, alpha=wd)
            new_p_fp32.addcmul_(update, lr_t, value=-1.0)
            self._stochastic_copy_(p, new_p_fp32)

        p.grad = None

    # ----------------------------------------------------------- optimizer API

    @torch.no_grad()
    def step(self, closure=None):
        """Run the (non-fused) parameter updates and apply the pooled lr vote.

        Returns the closure's loss when a closure is given, else None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Fused mode already updated every param in the backward hook; only
        # the pooled lr vote is left to apply. Non-fused mode does the real
        # work here.
        if not self.fused:
            for group in self.param_groups:
                for p in group["params"]:
                    if not p.requires_grad:
                        continue
                    # Low-precision grads were stochastically accumulated into
                    # _accum_grad; hand it back as the grad to update from.
                    accum = getattr(p, "_accum_grad", None)
                    if accum is not None:
                        p.grad = accum
                        del p._accum_grad
                    if p.grad is None:
                        continue
                    self._update_param(p, group)
        self._apply_group_votes()
        return loss

    def _apply_group_votes(self) -> None:
        # ONE lr nudge per group per step, from the pooled vote of every
        # element of every tensor in the group (see the class docstring on
        # why pooling at group level is load-bearing). Each param's lr tensor
        # receives the same multiplicative factor, so they stay identical --
        # effectively a single group lr, stored per param only so it rides
        # the normal state_dict machinery. All tensor ops: no GPU sync.
        for gi, group in enumerate(self.param_groups):
            num = self._group_num[gi]
            if num is None:
                continue
            den = self._group_den[gi]
            signal = num.div_(den.clamp_(min=1e-30)).clamp_(-1.0, 1.0)
            factor = torch.exp(signal)
            for p in group["params"]:
                st = self.state.get(p)
                if st is None or "lr" not in st:
                    continue
                lr_t = st["lr"]
                f = factor if factor.device == lr_t.device else factor.to(lr_t.device)
                # Numerical overflow guard only -- NOT a control rail
                # (decades outside the usable range).
                lr_t.mul_(f).clamp_(min=1e-30, max=1e3)
            self._group_num[gi] = None
            self._group_den[gi] = None

    def get_learning_rates(self) -> List[float]:
        """Reporting helper: the (shared) lr of each param group.

        Groups with no initialised state yet report their configured start lr.
        """
        out = []
        for group in self.param_groups:
            lrs = [
                self.state[p]["lr"]
                for p in group["params"]
                if p in self.state and "lr" in self.state[p]
            ]
            if lrs:
                # The per-param lr mirrors may live on different devices in a
                # multi-device group; torch.stack needs them on one device.
                dev = lrs[0].device
                out.append(float(torch.stack([t.to(dev) for t in lrs]).mean()))
            else:
                out.append(float(group["lr"]))
        return out

    def get_avg_learning_rate(self) -> float:
        """Mean of the per-group learning rates (start lr if there are no groups)."""
        lrs = self.get_learning_rates()
        return sum(lrs) / len(lrs) if lrs else float(self.defaults["lr"])

    def load_state_dict(self, state_dict):
        # Parent casts every fp state tensor to param.dtype; force lr back to fp32
        # so subsequent lr bumps aren't rounded away on bf16 weights.
        super().load_state_dict(state_dict)
        # Hyperparameters are NOT loaded from the checkpoint: constructor args
        # always win, so any setting can be changed mid-run just by passing a
        # different value when resuming. Only the adaptive state is restored
        # -- the group lr and the sign history (when its geometry still
        # matches the current config).
        for group in self.param_groups:
            for k, v in self.defaults.items():
                group[k] = v
            # One lr per group: unify the restored lrs to their geometric
            # median (they are already identical for checkpoints from this
            # version; older per-tensor checkpoints land on a sane middle).
            lrs = [
                st["lr"]
                for p in group["params"]
                if (st := self.state.get(p)) is not None
                and isinstance(st.get("lr"), torch.Tensor)
            ]
            med = None
            if lrs:
                dev = lrs[0].device
                med = (
                    torch.stack([t.to(torch.float32).to(dev) for t in lrs])
                    .log_()
                    .median()
                    .exp_()
                )
            for p in group["params"]:
                st = self.state.get(p)
                if st is None:
                    continue
                if isinstance(st.get("lr"), torch.Tensor):
                    st["lr"] = st["lr"].to(torch.float32)
                    if med is not None:
                        st["lr"].copy_(med.to(st["lr"].device))
                # Sign history: keep it when its geometry matches the current
                # config (the parent cast it to param dtype; recover by shape).
                # On any mismatch (e.g. a checkpoint from an older window
                # layout) -- start fresh.
                numel = p.numel()
                H = group["polarity_history"]
                width = (numel + 7) // 8
                sh = st.get("sign_history")
                hist_ok = (
                    isinstance(sh, torch.Tensor)
                    and sh.shape == (H, width)
                    and isinstance(st.get("hist_idx"), int)
                    and 0 <= st["hist_idx"] < H
                    and isinstance(st.get("hist_fill"), int)
                    and 0 <= st["hist_fill"] <= H
                )
                if hist_ok:
                    st["sign_history"] = sh.to(torch.uint8)
                else:
                    st["sign_history"] = torch.zeros(
                        (H, width), dtype=torch.uint8, device=p.device
                    )
                    st["hist_idx"] = 0
                    st["hist_fill"] = 0
        # The parent rebuilt the group dicts; remap params to groups and
        # reset the vote accumulators.
        self._rebuild_group_index()
def _scale_or_one(numerator: Tensor, denominator: float) -> Tensor:
    # numerator / denominator where numerator > 0, else 1.0 -- always a tensor
    # and computed without a host sync (no Python ``if`` on a device value).
    scale = numerator / denominator
    return torch.where(numerator > 0, scale, torch.ones_like(scale))


def compute_scale_for_dtype(tensor, dtype):
    """
    Compute appropriate scale for the given tensor and target dtype.

    The result is always a 0-dim tensor (1.0 for degenerate input such as an
    all-zero tensor), so callers can both divide by it and ``copy_`` it into
    an existing scale tensor.

    Args:
        tensor: Input tensor to be quantized
        dtype: Target dtype for quantization
    Returns:
        Appropriate scale factor for the quantization
    Raises:
        ValueError: if ``dtype`` is not int8, uint8, float8_e4m3fn or float8_e5m2.
    """
    if dtype == torch.int8:
        return _scale_or_one(torch.max(torch.abs(tensor)), 127.0)
    elif dtype == torch.uint8:
        # NOTE(review): the scale is derived from the value range, but
        # quantize_tensor applies no zero-point offset -- confirm intended.
        range_val = torch.max(tensor) - torch.min(tensor)
        return _scale_or_one(range_val, 255.0)
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # For float8, we typically want to preserve the magnitude of the values
        # while fitting within the representable range of the format
        abs_max = torch.max(torch.abs(tensor))
        if dtype == torch.float8_e4m3fn:
            # e4m3fn has range [-448, 448] with no infinities
            max_representable = 448.0
        else:  # torch.float8_e5m2
            # e5m2 has range [-57344, 57344] with infinities
            max_representable = 57344.0
        return _scale_or_one(abs_max, max_representable)
    else:
        raise ValueError(f"Unsupported dtype for quantization: {dtype}")


def quantize_tensor(tensor, dtype):
    """
    Quantize a floating-point tensor to the target dtype with appropriate scaling.

    Args:
        tensor: Input tensor (float)
        dtype: Target dtype for quantization
    Returns:
        quantized_data: Quantized tensor
        scale: Scale factor used
    Raises:
        ValueError: if ``dtype`` is not a supported quantization dtype.
    """
    scale = compute_scale_for_dtype(tensor, dtype)

    if dtype == torch.int8:
        quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype)
    elif dtype == torch.uint8:
        quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype)
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # For float8, we scale and then cast directly to the target type
        # The casting operation will handle the appropriate rounding
        scaled_tensor = tensor / scale
        quantized_data = scaled_tensor.to(dtype)
    else:
        raise ValueError(f"Unsupported dtype for quantization: {dtype}")

    return quantized_data, scale


def update_parameter(target, result_float):
    """
    Updates a parameter tensor, handling both regular torch.Tensor and QBytesTensor cases
    with proper rescaling for quantized tensors.

    Args:
        target: The parameter to update (either torch.Tensor or QBytesTensor)
        result_float: The new values to assign (torch.Tensor)
    """
    if isinstance(target, QBytesTensor):
        # Get the target dtype from the existing quantized tensor
        target_dtype = target._data.dtype

        # Handle device placement
        device = target._data.device
        result_float = result_float.to(device)

        # Compute new quantized values and scale
        quantized_data, new_scale = quantize_tensor(result_float, target_dtype)

        # Update the internal tensors with newly computed values
        # (new_scale is always a tensor, so copy_ is valid even for all-zero input).
        target._data.copy_(quantized_data)
        target._scale.copy_(new_scale)
    else:
        # Regular tensor update
        target.copy_(result_float)


def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
    """
    Returns (mantissa_bits, total_bits) for each format.
    mantissa_bits excludes the implicit leading 1.

    Raises:
        ValueError: for any dtype not listed below.
    """
    if dtype == torch.float32:
        return 23, 32
    elif dtype == torch.bfloat16:
        return 7, 16
    elif dtype == torch.float16:
        return 10, 16
    elif dtype == torch.float8_e4m3fn:
        return 3, 8
    elif dtype == torch.float8_e5m2:
        return 2, 8
    elif dtype == torch.int8:
        return 0, 8  # Int8 doesn't have mantissa bits
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")


def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor):
    """Stochastically round ``source`` into the bf16 ``target`` in place.

    The bit trick below reinterprets fp32 bit patterns, so a non-fp32 source
    is first widened to fp32 (reinterpreting narrower dtypes directly would
    change the element count and fail).
    """
    # adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5
    if source.dtype != torch.float32:
        source = source.to(torch.float32)

    # create a random 16 bit integer
    result = torch.randint_like(
        source,
        dtype=torch.int32,
        low=0,
        high=(1 << 16),
    )

    # add the random number to the lower 16 bit of the mantissa
    result.add_(source.view(dtype=torch.int32))

    # mask off the lower 16 bit of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32

    # copy the higher 16 bit into the target tensor
    target.copy_(result.view(dtype=torch.float32))

    del result


def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None:
    """Copy ``source`` into ``target`` with stochastic rounding for low precision targets.

    fp32 targets are copied directly; bf16 targets use the bit-truncation
    helper; every other dtype goes through a noise-and-round path and
    ``update_parameter``. Both tensors must be on a non-CPU device (asserted).
    ``eps`` is accepted but not used by this function.
    """
    with torch.no_grad():
        # assert if target is on cpu, throw error
        assert target.device.type != 'cpu', "Target is on cpu!"
        assert source.device.type != 'cpu', "Source is on cpu!"

        if target.dtype == torch.float32:
            target.copy_(source)
            return
        if target.dtype == torch.bfloat16:
            copy_stochastic_bf16(target, source)
            return

        mantissa_bits, _ = get_format_params(target.dtype)
        round_factor = 2 ** (23 - mantissa_bits)

        # Add uniform noise for stochastic rounding
        # NOTE(review): this rounds on a fixed absolute grid of
        # 1 / round_factor rather than relative to each value's exponent --
        # confirm that is the intended behaviour for fp16 targets.
        noise = torch.rand_like(source, device=source.device) - 0.5
        rounded = torch.round(source * round_factor + noise)
        result_float = rounded / round_factor

        # Clamp for float8
        if target.dtype == torch.float8_e4m3fn:
            result_float.clamp_(-448.0, 448.0)
        elif target.dtype == torch.float8_e5m2:
            result_float.clamp_(-57344.0, 57344.0)

        update_parameter(target, result_float)


class Auto8bitTensor:
    """Symmetric int8 quantized copy of a tensor with a single scalar scale.

    Construct from a tensor (quantizes it) or from a dict previously produced
    by ``state_dict()`` (restores it).
    """

    def __init__(self, data: Tensor, *args, **kwargs):
        if isinstance(data, dict):  # Add constructor from state dict
            self._load_from_state_dict(data)
        else:
            abs_max = data.abs().max().item()
            scale = abs_max / 127.0 if abs_max > 0 else 1.0

            self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
            self.scale = scale
            self.orig_dtype = data.dtype

    def dequantize(self) -> Tensor:
        """Return the fp32 reconstruction ``quantized * scale``."""
        return self.quantized.to(dtype=torch.float32) * self.scale

    def to(self, *args, **kwargs):
        """Dequantize, then forward to ``Tensor.to`` with the given arguments."""
        # Handle the dtype argument whether it's positional or keyword
        dtype = None
        if args and isinstance(args[0], torch.dtype):
            dtype = args[0]
            args = args[1:]
        elif 'dtype' in kwargs:
            dtype = kwargs['dtype']
            del kwargs['dtype']

        if dtype is not None:
            # First dequantize then convert to requested dtype
            return self.dequantize().to(dtype=dtype, *args, **kwargs)

        # If no dtype specified, just pass through to parent
        return self.dequantize().to(*args, **kwargs)

    def state_dict(self):
        """Returns a dictionary containing the current state of the tensor."""
        return {
            'quantized': self.quantized,
            'scale': self.scale,
            'orig_dtype': self.orig_dtype
        }

    def _load_from_state_dict(self, state_dict):
        """Loads the tensor state from a state dictionary."""
        self.quantized = state_dict['quantized']
        self.scale = state_dict['scale']
        self.orig_dtype = state_dict['orig_dtype']

    def __str__(self):
        return f"Auto8bitTensor({self.dequantize()})"


def stochastic_grad_accummulation(param):
    """Fold ``param.grad`` into ``param._accum_grad`` with stochastic rounding.

    The first call stores a clone of the grad; later calls add in fp32 and
    stochastically round the sum back into the buffer. The grad is released
    afterwards. Does nothing when there is no grad to accumulate.
    """
    if param.grad is None:
        return
    if hasattr(param, "_accum_grad"):
        # Single fp32 working copy (copy=True keeps it distinct from the
        # buffer even when the buffer is already fp32).
        grad_fp32 = param._accum_grad.to(torch.float32, copy=True)
        grad_fp32.add_(param.grad.to(torch.float32))
        copy_stochastic(param._accum_grad, grad_fp32)
        del grad_fp32
        del param.grad
    else:
        param._accum_grad = param.grad.clone()
        del param.grad
+ weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + decouple (boolean): + Use AdamW style decoupled weight decay + use_bias_correction (boolean): + Turn on Adam's bias correction. Off by default. + safeguard_warmup (boolean): + Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. + d0 (float): + Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. + d_coef (float): + Coefficient in the expression for the estimate of d (default 1.0). + Values such as 0.5 and 2.0 typically work as well. + Changing this parameter is the preferred way to tune the method. + growth_rate (float): + prevent the D estimate from growing faster than this multiplicative rate. + Default is inf, for unrestricted. Values like 1.02 give a kind of learning + rate warmup effect. + fsdp_in_use (bool): + If you're using sharded parameters, this should be set to True. The optimizer + will attempt to auto-detect this, but if you're using an implementation other + than PyTorch's builtin version, the auto-detection won't work. 
+ """ + + def __init__(self, params, lr=1.0, + betas=(0.9, 0.999), beta3=None, + eps=1e-8, weight_decay=0, decouple=True, + use_bias_correction=False, safeguard_warmup=False, + d0=1e-6, d_coef=1.0, growth_rate=float('inf'), + fsdp_in_use=False): + if not 0.0 < d0: + raise ValueError("Invalid d0 value: {}".format(d0)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1])) + + if decouple and weight_decay > 0: + print(f"Using decoupled weight decay") + + defaults = dict(lr=lr, betas=betas, beta3=beta3, + eps=eps, weight_decay=weight_decay, + d=d0, d0=d0, d_max=d0, + d_numerator=0.0, d_coef=d_coef, + k=0, growth_rate=growth_rate, + use_bias_correction=use_bias_correction, + decouple=decouple, safeguard_warmup=safeguard_warmup, + fsdp_in_use=fsdp_in_use) + self.d0 = d0 + super(Prodigy8bit, self).__init__(params, defaults) + + self.is_stochastic_rounding_accumulation = False + + # setup stochastic grad accum hooks + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and param.dtype != torch.float32: + self.is_stochastic_rounding_accumulation = True + param.register_post_accumulate_grad_hook( + stochastic_grad_accummulation + ) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + def step_hook(self): + if not self.is_stochastic_rounding_accumulation: + return + # copy over stochastically rounded grads + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and hasattr(param, "_accum_grad"): + param.grad = param._accum_grad + del param._accum_grad + + @torch.no_grad() + def step(self, 
closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + # call pre step + self.step_hook() + loss = None + if closure is not None: + loss = closure() + + d_denom = 0.0 + + group = self.param_groups[0] + use_bias_correction = group['use_bias_correction'] + beta1, beta2 = group['betas'] + beta3 = group['beta3'] + if beta3 is None: + beta3 = math.sqrt(beta2) + k = group['k'] + + d = group['d'] + d_max = group['d_max'] + d_coef = group['d_coef'] + lr = max(group['lr'] for group in self.param_groups) + + if use_bias_correction: + bias_correction = ((1 - beta2**(k+1))**0.5) / (1 - beta1**(k+1)) + else: + bias_correction = 1 + + dlr = d*lr*bias_correction + + growth_rate = group['growth_rate'] + decouple = group['decouple'] + fsdp_in_use = group['fsdp_in_use'] + + d_numerator = group['d_numerator'] + d_numerator *= beta3 + + for group in self.param_groups: + decay = group['weight_decay'] + k = group['k'] + eps = group['eps'] + group_lr = group['lr'] + d0 = group['d0'] + safeguard_warmup = group['safeguard_warmup'] + + if group_lr not in [lr, 0.0]: + raise RuntimeError( + f"Setting different lr values in different parameter groups is only supported for values of 0") + + for p in group['params']: + if p.grad is None: + continue + if hasattr(p, "_fsdp_flattened"): + fsdp_in_use = True + + grad = p.grad.data.to(torch.float32) + p_fp32 = p.clone().to(torch.float32) + + # Apply weight decay (coupled variant) + if decay != 0 and not decouple: + grad.add_(p_fp32.data, alpha=decay) + + state = self.state[p] + + # State initialization + if 'step' not in state: + state['step'] = 0 + state['s'] = Auto8bitTensor( + torch.zeros_like(p_fp32.data).detach()) + state['p0'] = Auto8bitTensor(p_fp32.detach().clone()) + # Exponential moving average of gradient values + state['exp_avg'] = Auto8bitTensor( + torch.zeros_like(p_fp32.data).detach()) + # Exponential moving average of 
squared gradient values + state['exp_avg_sq'] = Auto8bitTensor( + torch.zeros_like(p_fp32.data).detach()) + + exp_avg = state['exp_avg'].to(torch.float32) + exp_avg_sq = state['exp_avg_sq'].to(torch.float32) + + s = state['s'].to(torch.float32) + p0 = state['p0'].to(torch.float32) + + if group_lr > 0.0: + # we use d / d0 instead of just d to avoid getting values that are too small + d_numerator += (d / d0) * dlr * torch.dot(grad.flatten(), + (p0.data - p_fp32.data).flatten()).item() + + # Adam EMA updates + exp_avg.mul_(beta1).add_(grad, alpha=d * (1-beta1)) + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=d * d * (1-beta2)) + + if safeguard_warmup: + s.mul_(beta3).add_(grad, alpha=((d / d0) * d)) + else: + s.mul_(beta3).add_(grad, alpha=((d / d0) * dlr)) + d_denom += s.abs().sum().item() + + # update state with stochastic rounding + state['exp_avg'] = Auto8bitTensor(exp_avg) + state['exp_avg_sq'] = Auto8bitTensor(exp_avg_sq) + state['s'] = Auto8bitTensor(s) + state['p0'] = Auto8bitTensor(p0) + + d_hat = d + + # if we have not done any progres, return + # if we have any gradients available, will have d_denom > 0 (unless \|g\|=0) + if d_denom == 0: + return loss + + if lr > 0.0: + if fsdp_in_use: + dist_tensor = torch.zeros(2).cuda() + dist_tensor[0] = d_numerator + dist_tensor[1] = d_denom + dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) + global_d_numerator = dist_tensor[0] + global_d_denom = dist_tensor[1] + else: + global_d_numerator = d_numerator + global_d_denom = d_denom + + d_hat = d_coef * global_d_numerator / global_d_denom + if d == group['d0']: + d = max(d, d_hat) + d_max = max(d_max, d_hat) + d = min(d_max, d * growth_rate) + + for group in self.param_groups: + group['d_numerator'] = global_d_numerator + group['d_denom'] = global_d_denom + group['d'] = d + group['d_max'] = d_max + group['d_hat'] = d_hat + + decay = group['weight_decay'] + k = group['k'] + eps = group['eps'] + + for p in group['params']: + if p.grad is None: + continue + grad = 
p.grad.data.to(torch.float32) + p_fp32 = p.clone().to(torch.float32) + + state = self.state[p] + + exp_avg = state['exp_avg'].to(torch.float32) + exp_avg_sq = state['exp_avg_sq'].to(torch.float32) + + state['step'] += 1 + + denom = exp_avg_sq.sqrt().add_(d * eps) + + # Apply weight decay (decoupled variant) + if decay != 0 and decouple: + p_fp32.data.add_(p_fp32.data, alpha=-decay * dlr) + + # Take step + p_fp32.data.addcdiv_(exp_avg, denom, value=-dlr) + # apply stochastic rounding + copy_stochastic(p.data, p_fp32.data) + + group['k'] = k + 1 + + return loss diff --git a/ai-toolkit/toolkit/optimizers/test_optimizers.py b/ai-toolkit/toolkit/optimizers/test_optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..9c48cb311b4c60b7eecf4a7814188d10c61ff38e --- /dev/null +++ b/ai-toolkit/toolkit/optimizers/test_optimizers.py @@ -0,0 +1,204 @@ +""" +Optimizer benchmark on a ~500M parameter transformer, run once per dtype +(float32, float16, bfloat16). + +Compares speed (ms/step) and peak VRAM across: + - AdamW (torch, unfused — traditional Python loop) + - AdamW8bit (bitsandbytes) + - Adafactor + - Automagic v1 + - Automagic v2 (fused-backward) + - Automagic v3 (fused-backward and traditional/unfused) + - Prodigy +""" +import contextlib +import gc +import io +import os +import sys +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Allow running this file directly: `python test_optimizers.py` without setting PYTHONPATH. +# Toolkit imports happen inside main() so they pick this up. 
+_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + + +# ---- model --------------------------------------------------------------- + +class TransformerBlock(nn.Module): + def __init__(self, d_model: int, n_heads: int, d_ff: int): + super().__init__() + self.n_heads = n_heads + self.d_head = d_model // n_heads + self.ln1 = nn.LayerNorm(d_model) + self.q = nn.Linear(d_model, d_model, bias=False) + self.k = nn.Linear(d_model, d_model, bias=False) + self.v = nn.Linear(d_model, d_model, bias=False) + self.o = nn.Linear(d_model, d_model, bias=False) + self.ln2 = nn.LayerNorm(d_model) + self.ffn_up = nn.Linear(d_model, d_ff, bias=False) + self.ffn_down = nn.Linear(d_ff, d_model, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, S, D = x.shape + h = self.ln1(x) + q = self.q(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2) + k = self.k(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2) + v = self.v(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2) + a = F.scaled_dot_product_attention(q, k, v, is_causal=True) + a = a.transpose(1, 2).contiguous().view(B, S, D) + x = x + self.o(a) + h = self.ln2(x) + x = x + self.ffn_down(F.gelu(self.ffn_up(h))) + return x + + +class Transformer(nn.Module): + def __init__(self, d_model=1024, n_heads=16, n_layers=40, d_ff=4096): + super().__init__() + self.blocks = nn.ModuleList([ + TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers) + ]) + self.norm = nn.LayerNorm(d_model) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for b in self.blocks: + x = b(x) + return self.norm(x) + + +# ---- benchmark ----------------------------------------------------------- + +DEVICE = "cuda" +DTYPES = [torch.float32, torch.float16, torch.bfloat16] +D_MODEL = 1024 +N_HEADS = 16 +N_LAYERS = 40 +D_FF = 4096 +BATCH = 1 +SEQ = 128 +WARMUP = 3 +ITERS = 10 + + +def build_model(dtype): + 
torch.manual_seed(0) + return Transformer(D_MODEL, N_HEADS, N_LAYERS, D_FF).to(DEVICE, dtype=dtype) + + +def benchmark(results: list, label: str, opt_factory, dtype): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + model = build_model(dtype) + # Some optimizers print on construction; mute that so the final table is clean. + with contextlib.redirect_stdout(io.StringIO()): + opt = opt_factory(model.parameters()) + x = torch.randn(BATCH, SEQ, D_MODEL, device=DEVICE, dtype=dtype) + + print(f" running {label}...", flush=True) + try: + for _ in range(WARMUP): + opt.zero_grad(set_to_none=True) + model(x).sum().backward() + opt.step() + torch.cuda.synchronize() + + t0 = time.perf_counter() + for _ in range(ITERS): + opt.zero_grad(set_to_none=True) + model(x).sum().backward() + opt.step() + torch.cuda.synchronize() + dt = (time.perf_counter() - t0) / ITERS * 1000 + peak = torch.cuda.max_memory_allocated() / 1024**3 + results.append({"label": label, "ms": dt, "peak": peak, "ok": True}) + except torch.cuda.OutOfMemoryError: + results.append({"label": label, "ms": float("inf"), "peak": float("inf"), "ok": False, "note": "OOM"}) + except Exception as e: + # An optimizer may not support a given dtype; record it and keep going. + print(f" {label} failed: {type(e).__name__}: {e}", flush=True) + results.append({"label": label, "ms": float("inf"), "peak": float("inf"), "ok": False, "note": "ERR"}) + finally: + # Fused-backward optimizers register post-accumulate-grad hooks on the + # params; without removing them the hook closures keep this model alive + # past `del`, inflating the *next* fused optimizer's peak. Unregister so + # each run reports its true isolated peak regardless of run order. 
+ for h in getattr(opt, "_hook_handles", []): + h.remove() + del opt, model + gc.collect() + torch.cuda.empty_cache() + + +def print_table(results: list, dtype_name: str): + results = sorted(results, key=lambda r: r["peak"]) + + headers = ["#", "Optimizer", "Peak VRAM", "Time/step"] + rows = [] + for i, r in enumerate(results, 1): + if not r["ok"]: + rows.append([str(i), r["label"], r.get("note", "OOM"), "-"]) + continue + rows.append([str(i), r["label"], f"{r['peak']:.2f} GB", f"{r['ms']:.1f} ms"]) + + widths = [max(len(str(row[c])) for row in [headers] + rows) for c in range(len(headers))] + + def fmt(row, sep=" │ "): + return sep.join(s.ljust(widths[c]) if c == 1 else s.rjust(widths[c]) for c, s in enumerate(row)) + + line_top = "─" * (sum(widths) + 3 * (len(widths) - 1)) + print() + print(f" dtype: {dtype_name}") + print(line_top) + print(fmt(headers)) + print(line_top) + for row in rows: + print(fmt(row)) + print(line_top) + + +def main(): + n_params = sum(p.numel() for p in build_model(torch.float32).parameters()) + print(f"Model: {N_LAYERS} blocks × d_model={D_MODEL} × d_ff={D_FF}") + print(f" {n_params/1e6:.1f}M params") + print(f"Step: batch={BATCH}, seq={SEQ}") + print(f"Timing: {WARMUP} warmup + {ITERS} timed iters") + print(f"Rounds: {', '.join(str(d).replace('torch.', '') for d in DTYPES)}") + + from toolkit.optimizers.automagic import Automagic + from toolkit.optimizers.automagic2 import Automagic2 + from toolkit.optimizers.automagic3 import Automagic3 + from toolkit.optimizers.adafactor import Adafactor + from prodigyopt import Prodigy + import bitsandbytes as bnb + + optimizers = [ + ("AdamW", lambda p: torch.optim.AdamW(p, lr=1e-4, eps=1e-6, foreach=False, fused=False)), + ("AdamW8bit", lambda p: bnb.optim.AdamW8bit(p, lr=1e-4, eps=1e-6)), + ("Adafactor", lambda p: Adafactor(p, lr=1e-4, scale_parameter=False, relative_step=False, warmup_init=False)), + ("Automagic v1", lambda p: Automagic(p, lr=1e-4)), + ("Automagic v2", lambda p: Automagic2(p, 
lr=1e-4)), + ("Automagic v3 fused", lambda p: Automagic3(p, lr=1e-4, fused=True)), + ("Automagic v3 unfused", lambda p: Automagic3(p, lr=1e-4, fused=False)), + ("Prodigy", lambda p: Prodigy(p, lr=1.0, eps=1e-6)), + ] + + for dtype in DTYPES: + dtype_name = str(dtype).replace("torch.", "") + print(f"\n=== Round: {dtype_name} ===") + results: list = [] + for label, factory in optimizers: + benchmark(results, label, factory, dtype) + print_table(results, dtype_name) + + +if __name__ == "__main__": + main() diff --git a/ai-toolkit/toolkit/orig_configs/sd_xl_refiner.yaml b/ai-toolkit/toolkit/orig_configs/sd_xl_refiner.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cab5fe283d77bf86e0f29e99f3ed0d3c7d9c752f --- /dev/null +++ b/ai-toolkit/toolkit/orig_configs/sd_xl_refiner.yaml @@ -0,0 +1,91 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2560 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 384 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 4 + context_dim: [1280, 1280, 1280, 1280] # 1280 + spatial_transformer_attn_type: softmax-xformers + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn and vector cond + - 
is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + legacy: False + freeze: True + layer: penultimate + always_return_pooled: True + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: aesthetic_score + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by one + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/ai-toolkit/toolkit/paths.py b/ai-toolkit/toolkit/paths.py new file mode 100644 index 0000000000000000000000000000000000000000..9f6a8cfb4b33551ec6548f2d7a30471a1cf16aa7 --- /dev/null +++ b/ai-toolkit/toolkit/paths.py @@ -0,0 +1,21 @@ +import os + +TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config') +KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps") +ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs") +DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs") +COMFY_MODELS_PATH = None + +# check if ENV variable is set +if 'MODELS_PATH' in os.environ: + MODELS_PATH = os.environ['MODELS_PATH'] +else: + MODELS_PATH = os.path.join(TOOLKIT_ROOT, "models") + + +def 
get_path(path): + # we allow absolute paths, but if it is not absolute, we assume it is relative to the toolkit root + if not os.path.isabs(path): + path = os.path.join(TOOLKIT_ROOT, path) + return path diff --git a/ai-toolkit/toolkit/photomaker.py b/ai-toolkit/toolkit/photomaker.py new file mode 100644 index 0000000000000000000000000000000000000000..8037969507854129cf342d8b3fae7a6d1ff7581e --- /dev/null +++ b/ai-toolkit/toolkit/photomaker.py @@ -0,0 +1,144 @@ +# Merge image encoder and fuse module to create an ID Encoder +# send multiple ID images, we can directly obtain the updated text encoder containing a stacked ID embedding + +import torch +import torch.nn as nn +from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection +from transformers.models.clip.configuration_clip import CLIPVisionConfig +from transformers import PretrainedConfig + +VISION_CONFIG_DICT = { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768 +} + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class FuseModule(nn.Module): + def __init__(self, embed_dim): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True) + self.layer_norm = nn.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = 
self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + ) -> torch.Tensor: + # id_embeds shape: [b, max_num_inputs, 1, 2048] + id_embeds = id_embeds.to(prompt_embeds.dtype) + num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case + batch_size, max_num_inputs = id_embeds.shape[:2] + # seq_length: 77 + seq_length = prompt_embeds.shape[1] + # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] + flat_id_embeds = id_embeds.view( + -1, id_embeds.shape[-2], id_embeds.shape[-1] + ) + # valid_id_mask [b*max_num_inputs] + valid_id_mask = ( + torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] + < num_inputs[:, None] + ) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) + class_tokens_mask = class_tokens_mask.view(-1) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) + # slice out the image token embeddings + image_token_embeds = prompt_embeds[class_tokens_mask] + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + return updated_prompt_embeds + +class PhotoMakerIDEncoder(CLIPVisionModelWithProjection): + def __init__(self, config=None, *model_args, **model_kwargs): + if config is None: + config = CLIPVisionConfig(**VISION_CONFIG_DICT) + super().__init__(config, *model_args, **model_kwargs) + self.visual_projection_2 = nn.Linear(1024, 1280, bias=False) + self.fuse_module = FuseModule(2048) + + def forward(self, 
id_pixel_values, prompt_embeds, class_tokens_mask): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + shared_id_embeds = self.vision_model(id_pixel_values)[1] + id_embeds = self.visual_projection(shared_id_embeds) + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + updated_prompt_embeds = self.fuse_module( + prompt_embeds, id_embeds, class_tokens_mask) + + return updated_prompt_embeds + + +class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection): + def __init__(self, config=None, *model_args, **model_kwargs): + if config is None: + config = CLIPVisionConfig(**VISION_CONFIG_DICT) + super().__init__(config, *model_args, **model_kwargs) + self.visual_projection_2 = nn.Linear(1024, 1280, bias=False) + + def forward(self, id_pixel_values, do_projection2=True, output_full=False): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + # last_hidden_state, 1, 257, 1024 + vision_output = self.vision_model(id_pixel_values, output_hidden_states=True) + shared_id_embeds = vision_output[1] + id_embeds = self.visual_projection(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + + if do_projection2: + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + + if output_full: + return id_embeds, vision_output + return id_embeds + + + +if __name__ == "__main__": + PhotoMakerIDEncoder() \ No newline at end of file diff --git a/ai-toolkit/toolkit/photomaker_pipeline.py b/ai-toolkit/toolkit/photomaker_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d6437b648e91e5d4e70abbcf0995d76dc1b00f81 --- /dev/null +++ 
# ai-toolkit/toolkit/photomaker_pipeline.py
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from collections import OrderedDict
import os
import PIL
import numpy as np

import torch
from torchvision import transforms as T

from safetensors import safe_open
from huggingface_hub.utils import validate_hf_hub_args
from transformers import CLIPImageProcessor, CLIPTokenizer
from diffusers import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
# `rescale_noise_cfg` is used by `__call__` when `guidance_rescale > 0`; without this
# import that code path raised NameError.
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from diffusers.utils import (
    _get_model_file,
    is_transformers_available,
    logging,
)

from .photomaker import PhotoMakerIDEncoder

PipelineImageInput = Union[
    PIL.Image.Image,
    torch.FloatTensor,
    List[PIL.Image.Image],
    List[torch.FloatTensor],
]


class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    """SDXL pipeline extended with a PhotoMaker ID encoder.

    `load_photomaker_adapter` attaches the ID encoder, its LoRA weights and a trigger-word
    token; `__call__` then merges ID-image embeddings into the prompt embeddings at the
    position of the class word that precedes the trigger word.
    """

    @validate_hf_hub_args
    def load_photomaker_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        weight_name: str,
        subfolder: str = '',
        trigger_word: str = 'img',
        **kwargs,
    ):
        """
        Load the PhotoMaker ID encoder and LoRA weights and register the trigger word.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
                      with the two top-level keys `id_encoder` and `lora_weights`.

            weight_name (`str`):
                The weight name NOT the path to the weight.

            subfolder (`str`, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.

            trigger_word (`str`, *optional*, defaults to `"img"`):
                The trigger word is used to identify the position of class word in the text prompt,
                and it is recommended not to set it as a common word.
                This trigger word must be placed after the class word when used, otherwise, it will affect the
                performance of the personalized generation.

        Raises:
            ValueError: if the loaded state dict lacks `id_encoder` or `lora_weights`.
        """

        # Hub / download options forwarded to `_get_model_file`.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        loaded_from_dict = isinstance(pretrained_model_name_or_path_or_dict, dict)
        if not loaded_from_dict:
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                # Split the flat safetensors namespace into the two expected sub-dicts.
                state_dict = {"id_encoder": {}, "lora_weights": {}}
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        if key.startswith("id_encoder."):
                            state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key)
                        elif key.startswith("lora_weights."):
                            state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key)
            else:
                # SECURITY NOTE(review): `torch.load` unpickles the file; only use non-safetensors
                # checkpoints from trusted sources.
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        # Presence check (the previous list comparison also depended on key order).
        missing_keys = [key for key in ("id_encoder", "lora_weights") if key not in state_dict]
        if missing_keys:
            raise ValueError(
                f"Required keys {missing_keys} are missing from the state dict "
                "(both `id_encoder` and `lora_weights` are required)."
            )

        # Never interpolate a whole state dict into the log line.
        source_label = "<in-memory state dict>" if loaded_from_dict else pretrained_model_name_or_path_or_dict

        self.trigger_word = trigger_word
        # load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet
        print(f"Loading PhotoMaker components [1] id_encoder from [{source_label}]...")
        id_encoder = PhotoMakerIDEncoder()
        id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
        id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype)
        self.id_encoder = id_encoder
        self.id_image_processor = CLIPImageProcessor()

        # load lora into models
        print(f"Loading PhotoMaker components [2] lora_weights from [{source_label}]")
        self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")

        # Add trigger word token
        if self.tokenizer is not None:
            self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)

        self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)

    def encode_prompt_with_trigger_word(
        self,
        prompt: str,
        prompt_2: Optional[str] = None,
        num_id_images: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        class_tokens_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Encode a prompt containing the trigger word and locate the class-word slots.

        The trigger-word token is removed from the token sequence, and the token right before
        it (the class word) is repeated `num_id_images` times so each ID image gets one slot.

        Args:
            prompt: prompt for the first tokenizer/encoder; must contain the trigger word exactly once.
            prompt_2: prompt for the second tokenizer/encoder; defaults to `prompt`.
            num_id_images: number of ID images, i.e. how many class-word slots to create.
            device: target device; defaults to `self._execution_device`.
            prompt_embeds: pre-computed embeddings; when given, tokenization is skipped and
                `class_tokens_mask` must be supplied as well.
            pooled_prompt_embeds: pre-computed pooled embeddings, returned as-is when
                `prompt_embeds` is given.
            class_tokens_mask: boolean mask marking the class-word slots.

        Returns:
            `(prompt_embeds, pooled_prompt_embeds, class_tokens_mask)`.

        Raises:
            ValueError: if neither `prompt` nor `prompt_embeds` is given, if `prompt_embeds` is
                given without `class_tokens_mask`, if the trigger word does not occur exactly
                once, or if no token precedes the trigger word.
        """
        device = device or self._execution_device

        if prompt is None and prompt_embeds is None:
            raise ValueError("Provide either `prompt` or `prompt_embeds`.")
        if prompt_embeds is not None and class_tokens_mask is None:
            raise ValueError("If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed.")

        # Find the token id of the trigger word
        image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word)

        # Define tokenizers and text encoders
        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
        text_encoders = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_embeds_list = []
            prompts = [prompt, prompt_2]
            for current_prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                input_ids = tokenizer.encode(current_prompt)  # TODO: batch encode
                clean_input_ids = []
                trigger_positions = []
                # Find the class word token based on the newly added trigger word token:
                # drop the trigger token and remember the index of the token before it.
                for token_id in input_ids:
                    if token_id == image_token_id:
                        trigger_positions.append(len(clean_input_ids) - 1)
                    else:
                        clean_input_ids.append(token_id)

                if len(trigger_positions) != 1:
                    raise ValueError(
                        "PhotoMaker requires exactly one trigger word in a prompt "
                        f"(found {len(trigger_positions)}). "
                        f"Trigger word: {self.trigger_word}, Prompt: {current_prompt}."
                    )
                class_token_index = trigger_positions[0]
                if class_token_index < 0:
                    # A negative index would silently wrap around to the end of the sequence.
                    raise ValueError(
                        "The trigger word must be placed after the class word. "
                        f"Trigger word: {self.trigger_word}, Prompt: {current_prompt}."
                    )

                # Expand the class word token and corresponding mask
                class_token = clean_input_ids[class_token_index]
                clean_input_ids = (
                    clean_input_ids[:class_token_index]
                    + [class_token] * num_id_images
                    + clean_input_ids[class_token_index + 1:]
                )

                # Truncation or padding
                max_len = tokenizer.model_max_length
                if len(clean_input_ids) > max_len:
                    clean_input_ids = clean_input_ids[:max_len]
                else:
                    clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
                        max_len - len(clean_input_ids)
                    )

                class_tokens_mask = [
                    class_token_index <= i < class_token_index + num_id_images
                    for i in range(len(clean_input_ids))
                ]

                clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
                class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)

                encoder_output = text_encoder(
                    clean_input_ids.to(device),
                    output_hidden_states=True,
                )

                # Only the pooled output of the final text encoder is kept (overwritten per iteration).
                pooled_prompt_embeds = encoder_output[0]
                prompt_embeds_list.append(encoder_output.hidden_states[-2])

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        class_tokens_mask = class_tokens_mask.to(device=device)  # TODO: ignoring two-prompt case

        return prompt_embeds, pooled_prompt_embeds, class_tokens_mask

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        # Added parameters (for PhotoMaker)
        input_id_images: PipelineImageInput = None,
        start_merge_step: int = 0,  # TODO: change to `style_strength_ratio` in the future
        class_tokens_mask: Optional[torch.LongTensor] = None,
        prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
        Only the parameters introduced by PhotoMaker are discussed here.
        For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

        Note: `denoising_end` is accepted for signature compatibility but is not used by this method.

        Args:
            input_id_images (`PipelineImageInput`, *optional*):
                Input ID Image to work with PhotoMaker. Required.
            start_merge_step (`int`, defaults to 0):
                Denoising steps with index `<= start_merge_step` use the text-only embeddings; later steps use
                the embeddings merged with the ID embeddings.
            class_tokens_mask (`torch.LongTensor`, *optional*):
                Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word.
            prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.

        Returns:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
            With `output_type="latent"` the output object is always returned, regardless of `return_dict`.

        Raises:
            ValueError: if `input_id_images` is missing, if `prompt_embeds` is given without
                `class_tokens_mask`, if `guidance_scale <= 1`, or if neither `prompt` nor
                `prompt_embeds_text_only` is available for the text-only branch.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )
        #
        if prompt_embeds is not None and class_tokens_mask is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`."
            )
        # check the input id images
        if input_id_images is None:
            raise ValueError(
                "Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline."
            )
        if not isinstance(input_id_images, list):
            input_id_images = [input_id_images]

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # The denoising loop below always concatenates negative and positive embeddings, so
        # guidance is mandatory. An explicit error (instead of `assert`) survives `python -O`.
        if not do_classifier_free_guidance:
            raise ValueError(
                "PhotoMaker pipeline requires classifier-free guidance; pass `guidance_scale` > 1.0."
            )

        # 3. Encode input prompt
        num_id_images = len(input_id_images)

        (
            prompt_embeds,
            pooled_prompt_embeds,
            class_tokens_mask,
        ) = self.encode_prompt_with_trigger_word(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_id_images=num_id_images,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            class_tokens_mask=class_tokens_mask,
        )

        # 4. Encode input prompt without the trigger word for delayed conditioning
        trigger_fragment = " " + self.trigger_word  # sensitive to white space
        if isinstance(prompt, str):
            prompt_text_only = prompt.replace(trigger_fragment, "")
        elif isinstance(prompt, list):
            prompt_text_only = [item.replace(trigger_fragment, "") for item in prompt]
        else:
            # `prompt` is None (caller passed `prompt_embeds`): the text-only branch can then
            # only come from pre-computed embeddings.
            if prompt_embeds_text_only is None:
                raise ValueError(
                    "When `prompt` is not provided, `prompt_embeds_text_only` has to be passed."
                )
            prompt_text_only = None
        # NOTE(review): `prompt_2` is forwarded unchanged, so a trigger word inside it is not stripped.
        (
            prompt_embeds_text_only,
            negative_prompt_embeds,
            pooled_prompt_embeds_text_only,  # TODO: replace the pooled_prompt_embeds with text only prompt
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt_text_only,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds_text_only,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds_text_only,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        )

        # 5. Prepare the input ID images
        dtype = next(self.id_encoder.parameters()).dtype
        if isinstance(input_id_images[0], torch.Tensor):
            # Previously this case left `id_pixel_values` unbound.
            # assumes each tensor is one already-preprocessed (C, H, W) image — TODO confirm
            id_pixel_values = torch.stack(list(input_id_images), dim=0)
        else:
            id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values

        id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype)  # TODO: multiple prompts

        # 6. Get the update text embedding with the stacked ID embedding
        prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

        # 7. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 8. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 10. Prepare added time ids & embeddings
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        # Doubled for the (negative, positive) halves of the guidance batch.
        add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 11. Denoising loop
        # Clamped at 0, matching the sibling pipelines in toolkit/pipelines.py.
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Delayed conditioning: text-only embeddings first, ID-merged embeddings afterwards.
                if i <= start_merge_step:
                    current_prompt_embeds = torch.cat(
                        [negative_prompt_embeds, prompt_embeds_text_only], dim=0
                    )
                    add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
                else:
                    current_prompt_embeds = torch.cat(
                        [negative_prompt_embeds, prompt_embeds], dim=0
                    )
                    add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=current_prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # make sure the VAE is in float32 mode, as it overflows in float16
        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            # Latent output returns early: no post-processing, no offload hook, `return_dict` ignored.
            image = latents
            return StableDiffusionXLPipelineOutput(images=image)

        # apply watermark if available
        # if self.watermark is not None:
        #     image = self.watermark.apply_watermark(image)

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
# ai-toolkit/toolkit/pipelines.py
import importlib
import inspect
from typing import Union, List, Optional, Dict, Any, Tuple, Callable

import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline, FluxControlPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from diffusers.utils import is_torch_xla_available
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
from diffusers.image_processor import PipelineImageInput
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms



# Optional XLA support: `xm` is only imported when torch_xla is available.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
    """SDXL pipeline intended to sample with k-diffusion samplers.

    NOTE(review): `__init__` unconditionally raises `NotImplementedError`, and the code that
    would set `self.sampler` / `self.k_diffusion_model` is commented out, so `__call__` cannot
    currently run on an instance built through the constructor.
    """

    def __init__(
            self,
            vae: 'AutoencoderKL',
            text_encoder: 'CLIPTextModel',
            text_encoder_2: 'CLIPTextModelWithProjection',
            tokenizer: 'CLIPTokenizer',
            tokenizer_2: 'CLIPTokenizer',
            unet: 'UNet2DConditionModel',
            scheduler: 'KarrasDiffusionSchedulers',
            force_zeros_for_empty_prompt: bool = True,
            add_watermarker: Optional[bool] = None,
    ):
        """Build the base SDXL pipeline, then raise `NotImplementedError`.

        NOTE(review): `force_zeros_for_empty_prompt` and `add_watermarker` are accepted but not
        forwarded to the base constructor.
        """
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
        )
        # Everything below is disabled because `ModelWrapper` is no longer imported (see the
        # commented-out import at the top of the module).
        raise NotImplementedError("This pipeline is not implemented yet")
        # self.sampler = None
        # scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
        # model = ModelWrapper(unet, scheduler.alphas_cumprod)
        # if scheduler.config.prediction_type == "v_prediction":
        #     self.k_diffusion_model = CompVisVDenoiser(model)
        # else:
        #     self.k_diffusion_model = CompVisDenoiser(model)

    def set_scheduler(self, scheduler_type: str):
        """Select the k-diffusion sampler by attribute name.

        Looks up `scheduler_type` on `k_diffusion.sampling` and stores it in `self.sampler`;
        raises `AttributeError` if no such attribute exists.
        """
        library = importlib.import_module("k_diffusion")
        sampling = getattr(library, "sampling")
        self.sampler = getattr(sampling, scheduler_type)

    @torch.no_grad()
    def __call__(
            self,
            prompt: Union[str, List[str]] = None,
            prompt_2: Optional[Union[str, List[str]]] = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            denoising_end: Optional[float] = None,
            guidance_scale: float = 5.0,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            negative_prompt_2: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
            callback_steps: int = 1,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            guidance_rescale: float = 0.0,
            original_size: Optional[Tuple[int, int]] = None,
            crops_coords_top_left: Tuple[int, int] = (0, 0),
            target_size: Optional[Tuple[int, int]] = None,
            use_karras_sigmas: bool = False,
    ):
        """Generate images by running `self.sampler` over a sigma schedule.

        The parameters mirror the base SDXL pipeline call, plus `use_karras_sigmas`, which
        switches the sigma schedule from `self.scheduler.sigmas` to `get_sigmas_karras`.

        Returns a `StableDiffusionXLPipelineOutput`, or a one-element tuple when
        `return_dict` is False.

        NOTE(review): `guidance_rescale`, `callback`, `extra_step_kwargs`, `num_warmup_steps`
        and the truncated `timesteps` are not referenced again after they are computed or
        received in this method.
        """

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = self._get_add_time_ids(
            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
        )

        # With guidance, conditioning tensors are stacked as (negative, positive).
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 7.1 Apply denoising_end
        # Only a float strictly between 0 and 1 truncates the schedule.
        if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        # 5. Prepare sigmas
        if use_karras_sigmas:
            sigma_min: float = self.k_diffusion_model.sigmas[0].item()
            sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
            sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
            sigmas = sigmas.to(device)
        else:
            sigmas = self.scheduler.sigmas
        sigmas = sigmas.to(prompt_embeds.dtype)

        # 5. Prepare latent variables
        # NOTE(review): second `prepare_latents` call in this method; it receives the latents
        # produced by the first call above — confirm this repetition is intended.
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Scale the starting latents by the first sigma of the schedule.
        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

        # 7. Define model function
        def model_fn(x, t):
            # NOTE(review): inputs are always duplicated and the output always chunked in two,
            # regardless of `do_classifier_free_guidance` — confirm behaviour for guidance_scale <= 1.
            latent_model_input = torch.cat([x] * 2)
            t = torch.cat([t] * 2)

            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            # noise_pred = self.unet(
            #     latent_model_input,
            #     t,
            #     encoder_hidden_states=prompt_embeds,
            #     cross_attention_kwargs=cross_attention_kwargs,
            #     added_cond_kwargs=added_cond_kwargs,
            #     return_dict=False,
            # )[0]

            noise_pred = self.k_diffusion_model(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,)[0]

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            return noise_pred


        # 8. Run k-diffusion solver
        sampler_kwargs = {}
        # should work without it
        noise_sampler_seed = None


        # Only samplers whose signature has a `noise_sampler` parameter get a Brownian-tree sampler.
        if "noise_sampler" in inspect.signature(self.sampler).parameters:
            min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
            noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
            sampler_kwargs["noise_sampler"] = noise_sampler

        latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            # NOTE(review): `run_safety_checker` is not defined in this class; confirm the base
            # class provides it.
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + crops_coords_top_left: Tuple[int, int] = (0, 0), + timestep: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. 
of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 
+ guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. 
+ """ + # if not predict_noise: + # # call parent + # return super().__call__( + # prompt=prompt, + # prompt_2=prompt_2, + # height=height, + # width=width, + # num_inference_steps=num_inference_steps, + # denoising_end=denoising_end, + # guidance_scale=guidance_scale, + # negative_prompt=negative_prompt, + # negative_prompt_2=negative_prompt_2, + # num_images_per_prompt=num_images_per_prompt, + # eta=eta, + # generator=generator, + # latents=latents, + # prompt_embeds=prompt_embeds, + # negative_prompt_embeds=negative_prompt_embeds, + # pooled_prompt_embeds=pooled_prompt_embeds, + # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + # output_type=output_type, + # return_dict=return_dict, + # callback=callback, + # callback_steps=callback_steps, + # cross_attention_kwargs=cross_attention_kwargs, + # guidance_rescale=guidance_rescale, + # original_size=original_size, + # crops_coords_top_left=crops_coords_top_left, + # target_size=target_size, + # ) + + # 0. Default height and width to unet + height = self.default_sample_size * self.vae_scale_factor + width = self.default_sample_size * self.vae_scale_factor + + original_size = (height, width) + target_size = (height, width) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. 
Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ).to(device) # TODO DOES NOT CAST ORIGINALLY + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + return noise_pred + + def enable_model_cpu_offload(self, gpu_id=0): + print('Called cpu offload', gpu_id) + # fuck off + pass + + +class CustomStableDiffusionPipeline(StableDiffusionPipeline): + + # replace the call so it matches SDXL call so we can use the same code and also stop early + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 7.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + # some of the inputs are to keep it compatible with sdx + def predict_noise( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: 
Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + crops_coords_top_left: Tuple[int, int] = (0, 0), + timestep: Optional[int] = None, + ): + + # 0. Default height and width to unet + height = self.unet.config.sample_size * self.vae_scale_factor + width = self.unet.config.sample_size * self.vae_scale_factor + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + return noise_pred + + +class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline): + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + denoising_start: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + 
pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. 
+ Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. 
of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. 
Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 8.2 Determine denoising_start + denoising_start_index = 0 + if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1: + discrete_timestep_start = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + denoising_start_index = len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps))) + + + with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar: + for i, t in enumerate(timesteps, start=denoising_start_index): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + + + +# TODO this is rough. 
class FluxWithCFGPipeline(FluxPipeline):
    """Flux pipeline variant that applies classifier-free guidance manually.

    The transformer's guidance embedding is bypassed for the duration of the
    call and the conditional / unconditional predictions are combined as
    ``uncond + guidance_scale * (text - uncond)``.

    TODO: this is rough. The unconditional and conditional passes should be
    stacked into a single batched forward instead of two separate calls.
    """

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        """Generate images with explicit positive / negative guidance.

        When ``guidance_scale`` is greater than 1.00001 the negative prompt
        (or negative embeddings) is encoded and the transformer is run twice
        per step; otherwise only the positive pass is executed.

        Returns ``FluxPipelineOutput(images=...)`` or, when ``return_dict`` is
        False, a one-element tuple ``(image,)``. With ``output_type="latent"``
        the packed latents are returned without VAE decoding.

        The transformer's guidance embedding is always restored before this
        method exits, including when an exception is raised part-way through.
        """
        # bypass the guidance embedding if there is one
        bypass_flux_guidance(self.transformer)
        try:
            height = height or self.default_sample_size * self.vae_scale_factor
            width = width or self.default_sample_size * self.vae_scale_factor

            # 1. Check inputs. Raise error if not correct
            self.check_inputs(
                prompt,
                prompt_2,
                height,
                width,
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
                max_sequence_length=max_sequence_length,
            )

            self._guidance_scale = guidance_scale
            self._joint_attention_kwargs = joint_attention_kwargs
            self._interrupt = False

            # 2. Define call parameters
            if prompt is not None and isinstance(prompt, str):
                batch_size = 1
            elif prompt is not None and isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            device = self._execution_device

            # 3. Encode the positive (and, if guiding, the negative) prompt
            lora_scale = (
                self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
            )
            (
                prompt_embeds,
                pooled_prompt_embeds,
                text_ids,
            ) = self.encode_prompt(
                prompt=prompt,
                prompt_2=prompt_2,
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                lora_scale=lora_scale,
            )
            if guidance_scale > 1.00001:
                (
                    negative_prompt_embeds,
                    negative_pooled_prompt_embeds,
                    negative_text_ids,
                ) = self.encode_prompt(
                    prompt=negative_prompt,
                    prompt_2=negative_prompt_2,
                    prompt_embeds=negative_prompt_embeds,
                    pooled_prompt_embeds=negative_pooled_prompt_embeds,
                    device=device,
                    num_images_per_prompt=num_images_per_prompt,
                    max_sequence_length=max_sequence_length,
                    lora_scale=lora_scale,
                )

            # 4. Prepare latent variables
            num_channels_latents = self.transformer.config.in_channels // 4
            latents, latent_image_ids = self.prepare_latents(
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                prompt_embeds.dtype,
                device,
                generator,
                latents,
            )

            # 5. Prepare timesteps
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
            image_seq_len = latents.shape[1]
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.base_image_seq_len,
                self.scheduler.config.max_image_seq_len,
                self.scheduler.config.base_shift,
                self.scheduler.config.max_shift,
            )
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler,
                num_inference_steps,
                device,
                timesteps,
                sigmas,
                mu=mu,
            )
            num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
            self._num_timesteps = len(timesteps)

            # 6. Denoising loop
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    if self.interrupt:
                        continue

                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                    timestep = t.expand(latents.shape[0]).to(latents.dtype)

                    # handle guidance
                    if self.transformer.config.guidance_embeds:
                        guidance = torch.tensor([guidance_scale], device=device)
                        guidance = guidance.expand(latents.shape[0])
                    else:
                        guidance = None

                    noise_pred_text = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        pooled_projections=pooled_prompt_embeds,
                        encoder_hidden_states=prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]

                    if guidance_scale > 1.00001:
                        # todo combine these
                        noise_pred_uncond = self.transformer(
                            hidden_states=latents,
                            timestep=timestep / 1000,
                            guidance=guidance,
                            pooled_projections=negative_pooled_prompt_embeds,
                            encoder_hidden_states=negative_prompt_embeds,
                            txt_ids=negative_text_ids,
                            img_ids=latent_image_ids,
                            joint_attention_kwargs=self.joint_attention_kwargs,
                            return_dict=False,
                        )[0]

                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    else:
                        noise_pred = noise_pred_text

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_dtype = latents.dtype
                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                    if latents.dtype != latents_dtype:
                        if torch.backends.mps.is_available():
                            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                            latents = latents.to(latents_dtype)

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                    # advance the progress bar once per completed scheduler step
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()

                    if XLA_AVAILABLE:
                        xm.mark_step()

            if output_type == "latent":
                image = latents

            else:
                latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
                latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
                image = self.vae.decode(latents, return_dict=False)[0]
                image = self.image_processor.postprocess(image, output_type=output_type)

            # Offload all models
            self.maybe_free_model_hooks()

            if not return_dict:
                return (image,)

            return FluxPipelineOutput(images=image)
        finally:
            # Always undo the bypass, even on failure, so the shared transformer
            # is never left in the patched state for later callers.
            restore_flux_guidance(self.transformer)
control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + control_image_idx: int = 0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. 
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. 
+ prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 
+ + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. 
Prepare latent variables + # num_channels_latents = self.transformer.config.in_channels // 8 + num_channels_latents = 128 // 8 + + # pull mask off control image if there is one it is a pil image + mask = None + if control_image is not None and self.do_inpainting and control_image.mode == "RGBA": + control_img_array = np.array(control_image) + mask = control_img_array[:, :, 3:4] + # scale it to 0 - 1 + mask = mask / 255.0 + # control image ideally would be a full image here + control_img_array = control_img_array[:, :, :3] + control_image = Image.fromarray(control_img_array.astype(np.uint8)) + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + if control_image.ndim == 4: + num_control_channels = num_channels_latents + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + if mask is not None: + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + mask = transform(mask).to(device, dtype=control_image.dtype).unsqueeze(0) + # resize mask to match control image + mask = F.interpolate(mask, size=(control_image.shape[2], control_image.shape[3]), mode="bilinear", align_corners=False) + mask = mask.to(device) + # apply the mask to the control image so the inpaint latent area is 0 + # mask is currently 0 for inpaint area and 1 for image area + control_image = control_image * mask + # invert mask so it is 1 for inpaint area and 0 for image area + mask = 1 - mask + control_image = torch.cat([control_image, mask], dim=1) + num_control_channels += 1 + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_control_channels, + 
height_control_image, + width_control_image, + ) + + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + control_image_list = [] + for idx in range(self.num_controls): + if idx == 0 and self.do_inpainting: + ctrl = torch.zeros_like(latents) + # do ones for mask and zeros for image + ctrl = torch.cat([ctrl, torch.ones_like(ctrl[:, :, :4])], dim=2) + control_image_list.append(ctrl) + else: + control_image_list.append(torch.zeros_like(latents)) + + control_image_list[control_image_idx] = control_image + + latent_model_input = torch.cat([latents] + control_image_list, dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) + + \ No newline at end of file diff --git a/ai-toolkit/toolkit/pixel_shuffle_encoder.py b/ai-toolkit/toolkit/pixel_shuffle_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8848e405081531920880657362c48858cb7de509 --- /dev/null +++ b/ai-toolkit/toolkit/pixel_shuffle_encoder.py @@ -0,0 +1,211 @@ +from diffusers import AutoencoderKL +from typing import Optional, Union +import torch +import torch.nn as nn +import numpy as np +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput +from diffusers.models.autoencoders.vae import DecoderOutput + + +class PixelMixer(nn.Module): + def __init__(self, in_channels, downscale_factor): + super(PixelMixer, self).__init__() + self.downscale_factor = downscale_factor + self.in_channels = in_channels + + def forward(self, x): + 
class AutoencoderPixelMixer(nn.Module):
    """Parameter-free stand-in for an ``AutoencoderKL``.

    Encoding is a pixel-unshuffle and decoding is the matching pixel-shuffle
    (both delegated to ``PixelMixer``). The class mimics the parts of the VAE
    interface used elsewhere (``config``, ``dtype``, ``device``, ``encode``,
    ``decode`` and the various enable/disable toggles, which are no-ops).
    """

    class FakeDist:
        """Deterministic replacement for a latent distribution object.

        There is nothing to sample: ``sample`` and ``mode`` both return the
        latent that was passed in.
        """

        def __init__(self, x):
            self._sample = x

        def sample(self, generator=None):
            # `generator` is accepted so callers written against a real
            # distribution (``sample(generator=...)``) keep working; it is unused.
            return self._sample

        def mode(self):
            return self._sample

    def __init__(self, in_channels=3, downscale_factor=8):
        """Build the mixer and a config matching the downscale factor.

        Raises:
            ValueError: if ``downscale_factor`` is neither 8 nor 16.
        """
        super().__init__()
        self.mixer = PixelMixer(in_channels, downscale_factor)
        self._dtype = torch.float32
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.config = Config()

        if downscale_factor == 8:
            # we go by len of block out channels in code, so simulate it
            self.config.block_out_channels = (1, 1, 1, 1)
            self.config.latent_channels = 192

        elif downscale_factor == 16:
            # we go by len of block out channels in code, so simulate it
            self.config.block_out_channels = (1, 1, 1, 1, 1)
            self.config.latent_channels = 768
        else:
            raise ValueError(
                f"downscale_factor {downscale_factor} not supported")

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    # mimic to from torch
    def to(self, *args, **kwargs):
        """Record the requested dtype / device, then defer to ``nn.Module.to``.

        The module holds no parameters, so the tracked ``dtype`` / ``device``
        are the only visible effect. Positional arguments are inspected as well
        as keywords so that calls such as ``vae.to(torch.bfloat16)`` or
        ``vae.to("cuda")`` update the reported values too.
        """
        for arg in args:
            if isinstance(arg, torch.dtype):
                self._dtype = arg
            elif isinstance(arg, (str, torch.device)):
                self._device = arg
            elif isinstance(arg, torch.Tensor):
                # nn.Module.to(tensor) adopts the tensor's dtype and device
                self._dtype = arg.dtype
                self._device = arg.device
        # keyword forms take precedence over positional ones
        if 'dtype' in kwargs:
            self._dtype = kwargs['dtype']
        if 'device' in kwargs:
            self._device = kwargs['device']
        return super().to(*args, **kwargs)

    def enable_xformers_memory_efficient_attention(self):
        pass

    # @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        """Pixel-unshuffle ``x`` into a latent.

        Returns ``(latent,)`` when ``return_dict`` is False, otherwise an
        ``AutoencoderKLOutput`` whose ``latent_dist`` is a ``FakeDist``.
        """
        h = self.mixer.encode(x)

        if not return_dict:
            return (h,)

        return AutoencoderKLOutput(latent_dist=self.FakeDist(h))

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        dec = self.mixer.decode(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        """Pixel-shuffle a latent back to image space."""
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def _set_gradient_checkpointing(self, module, value=False):
        pass

    def enable_tiling(self, use_tiling: bool = True):
        pass

    def disable_tiling(self):
        pass

    def enable_slicing(self):
        pass

    def disable_slicing(self):
        pass

    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
        pass

    def forward(
            self,
            sample: torch.FloatTensor,
            sample_posterior: bool = False,
            return_dict: bool = True,
            generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        """Encode then decode ``sample`` (a lossless round trip).

        ``sample_posterior`` and ``generator`` are honoured for interface
        compatibility; both paths yield the same latent because the fake
        distribution is deterministic.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
def print_acc(*args, **kwargs):
    """Print only on the accelerator's local main process."""
    if get_accelerator().is_local_main_process:
        print(*args, **kwargs)


class Logger:
    """File-like object that tees everything written to a terminal stream and a log file."""

    def __init__(self, filename, terminal=None):
        """
        :param filename: Log file path; opened in append mode.
        :param terminal: Stream to mirror writes to. Defaults to the current
            ``sys.stdout`` (the historical behaviour).
        """
        self.terminal = sys.stdout if terminal is None else terminal
        self.log = open(filename, 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()  # Make sure it's written immediately

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return self.terminal.isatty()


def setup_log_to_file(filename):
    """Mirror stdout and stderr into ``filename`` on the local main process.

    The parent directory is created when the path has one. Each stream is
    wrapped around its *own* original terminal: previously the stderr logger
    captured the already-replaced ``sys.stdout``, so stderr output was sent to
    stdout and written to the log file twice.
    """
    if get_accelerator().is_local_main_process:
        log_dir = os.path.dirname(filename)
        # dirname is '' for a bare filename; os.makedirs('') would raise
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # capture both originals before replacing either one
        original_stdout, original_stderr = sys.stdout, sys.stderr
        sys.stdout = Logger(filename, terminal=original_stdout)
        sys.stderr = Logger(filename, terminal=original_stderr)
def to(self, *args, **kwargs):
    """Move / cast every held tensor in place and return ``self``.

    All arguments are forwarded unchanged to ``Tensor.to``. ``text_embeds``
    and ``attention_mask`` may each be a single tensor or a list/tuple of
    tensors; a list/tuple is replaced by a new list of converted tensors.
    ``pooled_embeds`` and ``attention_mask`` are skipped when ``None``.
    """

    def _convert(value):
        # sequences are converted element-wise and always come back as a list
        if isinstance(value, (list, tuple)):
            return [item.to(*args, **kwargs) for item in value]
        return value.to(*args, **kwargs)

    self.text_embeds = _convert(self.text_embeds)

    if self.pooled_embeds is not None:
        self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)

    if self.attention_mask is not None:
        self.attention_mask = _convert(self.attention_mask)

    return self
= PromptEmbeds(cloned_text_embeds) + + if self.attention_mask is not None: + if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple): + prompt_embeds.attention_mask = [t.clone() for t in self.attention_mask] + else: + prompt_embeds.attention_mask = self.attention_mask.clone() + return prompt_embeds + + def expand_to_batch(self, batch_size): + pe = self.clone() + if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple): + if len(pe.text_embeds[0].shape) == 2: + current_batch_size = len(pe.text_embeds) + else: + current_batch_size = pe.text_embeds[0].shape[0] + else: + current_batch_size = pe.text_embeds.shape[0] + if current_batch_size == batch_size: + return pe + if current_batch_size != 1: + raise Exception("Can only expand batch size for batch size 1") + if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple): + if len(pe.text_embeds[0].shape) == 2: + # batch is a list of tensors + pe.text_embeds = pe.text_embeds * batch_size + else: + pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds] + else: + pe.text_embeds = pe.text_embeds.expand(batch_size, -1) + if pe.pooled_embeds is not None: + pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1) + if pe.attention_mask is not None: + if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple): + pe.attention_mask = [t.expand(batch_size, -1) for t in pe.attention_mask] + else: + pe.attention_mask = pe.attention_mask.expand(batch_size, -1) + return pe + + def save(self, path: str): + """ + Save the prompt embeds to a file. + :param path: The path to save the prompt embeds. 
+ """ + pe = self.clone() + state_dict = {} + if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple): + for i, text_embed in enumerate(pe.text_embeds): + state_dict[f"text_embed_{i}"] = text_embed.cpu() + else: + state_dict["text_embed"] = pe.text_embeds.cpu() + + if pe.pooled_embeds is not None: + state_dict["pooled_embed"] = pe.pooled_embeds.cpu() + if pe.attention_mask is not None: + if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple): + for i, attn in enumerate(pe.attention_mask): + state_dict[f"attention_mask_{i}"] = attn.cpu() + else: + state_dict["attention_mask"] = pe.attention_mask.cpu() + os.makedirs(os.path.dirname(path), exist_ok=True) + save_file(state_dict, path) + + @classmethod + def load(cls, path: str) -> 'PromptEmbeds': + """ + Load the prompt embeds from a file. + :param path: The path to load the prompt embeds from. + :return: An instance of PromptEmbeds. + """ + # first check if it is advanced prompt embed file + f = safe_open(path, framework='pt') + metadata = f.metadata() + if metadata is not None and metadata.get("class_name", "") == "AdvancedPromptEmbeds": + return AdvancedPromptEmbeds.load(path=path) + + state_dict = load_file(path, device='cpu') + text_embeds = [] + pooled_embeds = None + attention_mask = [] + is_list = False + for key in sorted(state_dict.keys()): + if key.startswith("text_embed_"): + is_list = True + text_embeds.append(state_dict[key]) + elif key == "text_embed": + text_embeds.append(state_dict[key]) + elif key == "pooled_embed": + pooled_embeds = state_dict[key] + elif key.startswith("attention_mask_"): + attention_mask.append(state_dict[key]) + elif key == "attention_mask": + attention_mask.append(state_dict[key]) + pe = cls(None) + pe.text_embeds = text_embeds + if len(text_embeds) == 1 and not is_list: + pe.text_embeds = text_embeds[0] + if pooled_embeds is not None: + pe.pooled_embeds = pooled_embeds + if len(attention_mask) > 0: + if len(attention_mask) == 1: + 
pe.attention_mask = attention_mask[0] + else: + pe.attention_mask = attention_mask + return pe + + + +class EncodedPromptPair: + def __init__( + self, + target_class, + target_class_with_neutral, + positive_target, + positive_target_with_neutral, + negative_target, + negative_target_with_neutral, + neutral, + empty_prompt, + both_targets, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + action_list=None, + multiplier=1.0, + multiplier_list=None, + weight=1.0, + target: 'SliderTargetConfig' = None, + ): + self.target_class: PromptEmbeds = target_class + self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral + self.positive_target: PromptEmbeds = positive_target + self.positive_target_with_neutral: PromptEmbeds = positive_target_with_neutral + self.negative_target: PromptEmbeds = negative_target + self.negative_target_with_neutral: PromptEmbeds = negative_target_with_neutral + self.neutral: PromptEmbeds = neutral + self.empty_prompt: PromptEmbeds = empty_prompt + self.both_targets: PromptEmbeds = both_targets + self.multiplier: float = multiplier + self.target: 'SliderTargetConfig' = target + if multiplier_list is not None: + self.multiplier_list: list[float] = multiplier_list + else: + self.multiplier_list: list[float] = [multiplier] + self.action: int = action + if action_list is not None: + self.action_list: list[int] = action_list + else: + self.action_list: list[int] = [action] + self.weight: float = weight + + # simulate torch to for tensors + def to(self, *args, **kwargs): + self.target_class = self.target_class.to(*args, **kwargs) + self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs) + self.positive_target = self.positive_target.to(*args, **kwargs) + self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs) + self.negative_target = self.negative_target.to(*args, **kwargs) + self.negative_target_with_neutral = self.negative_target_with_neutral.to(*args, **kwargs) + self.neutral = 
self.neutral.to(*args, **kwargs) + self.empty_prompt = self.empty_prompt.to(*args, **kwargs) + self.both_targets = self.both_targets.to(*args, **kwargs) + return self + + def detach(self): + self.target_class = self.target_class.detach() + self.target_class_with_neutral = self.target_class_with_neutral.detach() + self.positive_target = self.positive_target.detach() + self.positive_target_with_neutral = self.positive_target_with_neutral.detach() + self.negative_target = self.negative_target.detach() + self.negative_target_with_neutral = self.negative_target_with_neutral.detach() + self.neutral = self.neutral.detach() + self.empty_prompt = self.empty_prompt.detach() + self.both_targets = self.both_targets.detach() + return self + + +def concat_prompt_embeds(prompt_embeds: list["PromptEmbeds"], padding_side: str = "right") -> PromptEmbeds: + # check if first item has a classmethod of concat_prompt_embeds + if hasattr(prompt_embeds[0].__class__, "concat_prompt_embeds"): + return prompt_embeds[0].__class__.concat_prompt_embeds(prompt_embeds, padding_side=padding_side) + # --- pad text_embeds --- + if isinstance(prompt_embeds[0].text_embeds, (list, tuple)): + text_embeds = [] + for p in prompt_embeds: + text_embeds += p.text_embeds + else: + max_len = max(p.text_embeds.shape[1] for p in prompt_embeds) + padded = [] + for p in prompt_embeds: + t = p.text_embeds + if t.shape[1] < max_len: + pad = torch.zeros( + (t.shape[0], max_len - t.shape[1], *t.shape[2:]), + dtype=t.dtype, + device=t.device, + ) + if padding_side == "right": + t = torch.cat([t, pad], dim=1) + else: + t = torch.cat([pad, t], dim=1) + padded.append(t) + text_embeds = torch.cat(padded, dim=0) + + # --- pooled embeds --- + pooled_embeds = None + if prompt_embeds[0].pooled_embeds is not None: + pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0) + + # --- attention mask --- + attention_mask = None + if prompt_embeds[0].attention_mask is not None: + max_len = 
max(p.attention_mask.shape[1] for p in prompt_embeds) + padded = [] + for p in prompt_embeds: + m = p.attention_mask + if m.shape[1] < max_len: + pad = torch.zeros( + (m.shape[0], max_len - m.shape[1]), + dtype=m.dtype, + device=m.device, + ) + if padding_side == "right": + m = torch.cat([m, pad], dim=1) + else: + m = torch.cat([pad, m], dim=1) + padded.append(m) + attention_mask = torch.cat(padded, dim=0) + + # wrap back into PromptEmbeds + pe = PromptEmbeds([text_embeds, pooled_embeds]) + pe.attention_mask = attention_mask + return pe + + +def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]): + weight = prompt_pairs[0].weight + target_class = concat_prompt_embeds([p.target_class for p in prompt_pairs]) + target_class_with_neutral = concat_prompt_embeds([p.target_class_with_neutral for p in prompt_pairs]) + positive_target = concat_prompt_embeds([p.positive_target for p in prompt_pairs]) + positive_target_with_neutral = concat_prompt_embeds([p.positive_target_with_neutral for p in prompt_pairs]) + negative_target = concat_prompt_embeds([p.negative_target for p in prompt_pairs]) + negative_target_with_neutral = concat_prompt_embeds([p.negative_target_with_neutral for p in prompt_pairs]) + neutral = concat_prompt_embeds([p.neutral for p in prompt_pairs]) + empty_prompt = concat_prompt_embeds([p.empty_prompt for p in prompt_pairs]) + both_targets = concat_prompt_embeds([p.both_targets for p in prompt_pairs]) + # combine all the lists + action_list = [] + multiplier_list = [] + weight_list = [] + for p in prompt_pairs: + action_list += p.action_list + multiplier_list += p.multiplier_list + return EncodedPromptPair( + target_class=target_class, + target_class_with_neutral=target_class_with_neutral, + positive_target=positive_target, + positive_target_with_neutral=positive_target_with_neutral, + negative_target=negative_target, + negative_target_with_neutral=negative_target_with_neutral, + neutral=neutral, + empty_prompt=empty_prompt, + 
both_targets=both_targets, + action_list=action_list, + multiplier_list=multiplier_list, + weight=weight, + target=prompt_pairs[0].target + ) + + +def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]: + if hasattr(concatenated.__class__, "split_prompt_embeds"): + return concatenated.__class__.split_prompt_embeds(concatenated, num_parts=num_parts) + if num_parts is None: + # use batch size + num_parts = concatenated.text_embeds.shape[0] + + if isinstance(concatenated.text_embeds, list) or isinstance(concatenated.text_embeds, tuple): + # split each part + text_embeds_splits = [ + torch.chunk(text, num_parts, dim=0) + for text in concatenated.text_embeds + ] + text_embeds_splits = list(zip(*text_embeds_splits)) + else: + text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0) + + if concatenated.pooled_embeds is not None: + pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0) + else: + pooled_embeds_splits = [None] * num_parts + + prompt_embeds_list = [ + PromptEmbeds([text, pooled]) + for text, pooled in zip(text_embeds_splits, pooled_embeds_splits) + ] + + return prompt_embeds_list + + +def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List[EncodedPromptPair]: + target_class_splits = split_prompt_embeds(concatenated.target_class, num_embeds) + target_class_with_neutral_splits = split_prompt_embeds(concatenated.target_class_with_neutral, num_embeds) + positive_target_splits = split_prompt_embeds(concatenated.positive_target, num_embeds) + positive_target_with_neutral_splits = split_prompt_embeds(concatenated.positive_target_with_neutral, num_embeds) + negative_target_splits = split_prompt_embeds(concatenated.negative_target, num_embeds) + negative_target_with_neutral_splits = split_prompt_embeds(concatenated.negative_target_with_neutral, num_embeds) + neutral_splits = split_prompt_embeds(concatenated.neutral, num_embeds) + empty_prompt_splits = 
split_prompt_embeds(concatenated.empty_prompt, num_embeds) + both_targets_splits = split_prompt_embeds(concatenated.both_targets, num_embeds) + + prompt_pairs = [] + for i in range(len(target_class_splits)): + action_list_split = concatenated.action_list[i::len(target_class_splits)] + multiplier_list_split = concatenated.multiplier_list[i::len(target_class_splits)] + + prompt_pair = EncodedPromptPair( + target_class=target_class_splits[i], + target_class_with_neutral=target_class_with_neutral_splits[i], + positive_target=positive_target_splits[i], + positive_target_with_neutral=positive_target_with_neutral_splits[i], + negative_target=negative_target_splits[i], + negative_target_with_neutral=negative_target_with_neutral_splits[i], + neutral=neutral_splits[i], + empty_prompt=empty_prompt_splits[i], + both_targets=both_targets_splits[i], + action_list=action_list_split, + multiplier_list=multiplier_list_split, + weight=concatenated.weight, + target=concatenated.target + ) + prompt_pairs.append(prompt_pair) + + return prompt_pairs + + +class PromptEmbedsCache: + prompts: dict[str, PromptEmbeds] = {} + + def __setitem__(self, __name: str, __value: PromptEmbeds) -> None: + self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PromptEmbeds]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + +class EncodedAnchor: + def __init__( + self, + prompt, + neg_prompt, + multiplier=1.0, + multiplier_list=None + ): + self.prompt = prompt + self.neg_prompt = neg_prompt + self.multiplier = multiplier + + if multiplier_list is not None: + self.multiplier_list: list[float] = multiplier_list + else: + self.multiplier_list: list[float] = [multiplier] + + def to(self, *args, **kwargs): + self.prompt = self.prompt.to(*args, **kwargs) + self.neg_prompt = self.neg_prompt.to(*args, **kwargs) + return self + + +def concat_anchors(anchors: list[EncodedAnchor]): + prompt = concat_prompt_embeds([a.prompt for a in anchors]) + 
neg_prompt = concat_prompt_embeds([a.neg_prompt for a in anchors]) + return EncodedAnchor( + prompt=prompt, + neg_prompt=neg_prompt, + multiplier_list=[a.multiplier for a in anchors] + ) + + +def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]: + prompt_splits = split_prompt_embeds(concatenated.prompt, num_anchors) + neg_prompt_splits = split_prompt_embeds(concatenated.neg_prompt, num_anchors) + multiplier_list_splits = torch.chunk(torch.tensor(concatenated.multiplier_list), num_anchors) + + anchors = [] + for prompt, neg_prompt, multiplier in zip(prompt_splits, neg_prompt_splits, multiplier_list_splits): + anchor = EncodedAnchor( + prompt=prompt, + neg_prompt=neg_prompt, + multiplier=multiplier.tolist() + ) + anchors.append(anchor) + + return anchors + + +def get_permutations(s, max_permutations=8): + # Split the string by comma + phrases = [phrase.strip() for phrase in s.split(',')] + + # remove empty strings + phrases = [phrase for phrase in phrases if len(phrase) > 0] + # shuffle the list + random.shuffle(phrases) + + # Get all permutations + permutations = list([p for p in itertools.islice(itertools.permutations(phrases), max_permutations)]) + + # Convert the tuples back to comma separated strings + return [', '.join(permutation) for permutation in permutations] + + +def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']: + from toolkit.config_modules import SliderTargetConfig + pos_permutations = get_permutations(target.positive, max_permutations=max_permutations) + neg_permutations = get_permutations(target.negative, max_permutations=max_permutations) + + permutations = [] + for pos, neg in itertools.product(pos_permutations, neg_permutations): + permutations.append( + SliderTargetConfig( + target_class=target.target_class, + positive=pos, + negative=neg, + multiplier=target.multiplier, + weight=target.weight + ) + ) + + # shuffle the list + 
random.shuffle(permutations) + + if len(permutations) > max_permutations: + permutations = permutations[:max_permutations] + + return permutations + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + + +@torch.no_grad() +def encode_prompts_to_cache( + prompt_list: list[str], + sd: "StableDiffusion", + cache: Optional[PromptEmbedsCache] = None, + prompt_tensor_file: Optional[str] = None, +) -> PromptEmbedsCache: + # TODO: add support for larger prompts + if cache is None: + cache = PromptEmbedsCache() + + if prompt_tensor_file is not None: + # check to see if it exists + if os.path.exists(prompt_tensor_file): + # load it. + print(f"Loading prompt tensors from {prompt_tensor_file}") + prompt_tensors = load_file(prompt_tensor_file, device='cpu') + # add them to the cache + for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False): + if prompt_txt.startswith("te:"): + prompt = prompt_txt[3:] + # text_embeds + text_embeds = prompt_tensor + pooled_embeds = None + # find pool embeds + if f"pe:{prompt}" in prompt_tensors: + pooled_embeds = prompt_tensors[f"pe:{prompt}"] + + # make it + prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds]) + cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32) + + if len(cache.prompts) == 0: + print("Prompt tensors not found. Encoding prompts..") + empty_prompt = "" + # encode empty_prompt + cache[empty_prompt] = sd.encode_prompt(empty_prompt) + + for p in tqdm(prompt_list, desc="Encoding prompts", leave=False): + # build the cache + if cache[p] is None: + cache[p] = sd.encode_prompt(p).to(device="cpu", dtype=torch.float16) + + # should we shard? 
It can get large + if prompt_tensor_file: + print(f"Saving prompt tensors to {prompt_tensor_file}") + state_dict = {} + for prompt_txt, prompt_embeds in cache.prompts.items(): + state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to( + "cpu", dtype=get_torch_dtype('fp16') + ) + if prompt_embeds.pooled_embeds is not None: + state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to( + "cpu", + dtype=get_torch_dtype('fp16') + ) + save_file(state_dict, prompt_tensor_file) + + return cache + + +@torch.no_grad() +def build_prompt_pair_batch_from_cache( + cache: PromptEmbedsCache, + target: 'SliderTargetConfig', + neutral: Optional[str] = '', +) -> list[EncodedPromptPair]: + erase_negative = len(target.positive.strip()) == 0 + enhance_positive = len(target.negative.strip()) == 0 + + both = not erase_negative and not enhance_positive + + prompt_pair_batch = [] + + if both or erase_negative: + # print("Encoding erase negative") + prompt_pair_batch += [ + # erase standard + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.positive}"], + positive_target_with_neutral=cache[f"{target.positive} {neutral}"], + negative_target=cache[f"{target.negative}"], + negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier, + both_targets=cache[f"{target.positive} {target.negative}"], + empty_prompt=cache[""], + weight=target.weight, + target=target + ), + ] + if both or enhance_positive: + # print("Encoding enhance positive") + prompt_pair_batch += [ + # enhance standard, swap pos neg + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.negative}"], + positive_target_with_neutral=cache[f"{target.negative} {neutral}"], + 
negative_target=cache[f"{target.positive}"], + negative_target_with_neutral=cache[f"{target.positive} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier, + both_targets=cache[f"{target.positive} {target.negative}"], + empty_prompt=cache[""], + weight=target.weight, + target=target + ), + ] + if both or enhance_positive: + # print("Encoding erase positive (inverse)") + prompt_pair_batch += [ + # erase inverted + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.negative}"], + positive_target_with_neutral=cache[f"{target.negative} {neutral}"], + negative_target=cache[f"{target.positive}"], + negative_target_with_neutral=cache[f"{target.positive} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + both_targets=cache[f"{target.positive} {target.negative}"], + empty_prompt=cache[""], + multiplier=target.multiplier * -1.0, + weight=target.weight, + target=target + ), + ] + if both or erase_negative: + # print("Encoding enhance negative (inverse)") + prompt_pair_batch += [ + # enhance inverted + EncodedPromptPair( + target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], + positive_target=cache[f"{target.positive}"], + positive_target_with_neutral=cache[f"{target.positive} {neutral}"], + negative_target=cache[f"{target.negative}"], + negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + both_targets=cache[f"{target.positive} {target.negative}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + empty_prompt=cache[""], + multiplier=target.multiplier * -1.0, + weight=target.weight, + target=target + ), + ] + + return prompt_pair_batch + + +def build_latent_image_batch_for_prompt_pair( + pos_latent, + neg_latent, + prompt_pair: EncodedPromptPair, + prompt_chunk_size 
+): + erase_negative = len(prompt_pair.target.positive.strip()) == 0 + enhance_positive = len(prompt_pair.target.negative.strip()) == 0 + both = not erase_negative and not enhance_positive + + prompt_pair_chunks = split_prompt_pairs(prompt_pair, prompt_chunk_size) + if both and len(prompt_pair_chunks) != 4: + raise Exception("Invalid prompt pair chunks") + if (erase_negative or enhance_positive) and len(prompt_pair_chunks) != 2: + raise Exception("Invalid prompt pair chunks") + + latent_list = [] + + if both or erase_negative: + latent_list.append(pos_latent) + if both or enhance_positive: + latent_list.append(pos_latent) + if both or enhance_positive: + latent_list.append(neg_latent) + if both or erase_negative: + latent_list.append(neg_latent) + + return torch.cat(latent_list, dim=0) + + +def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True): + if trigger is None: + # process as empty string to remove any [trigger] tokens + trigger = '' + output_prompt = prompt + default_replacements = ["[name]", "[trigger]"] + + replace_with = trigger + if to_replace_list is None: + to_replace_list = default_replacements + else: + to_replace_list += default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + if trigger.strip() != "": + # see how many times replace_with is in the prompt + num_instances = output_prompt.count(replace_with) + + if num_instances == 0 and add_if_not_present: + # add it to the beginning of the prompt + output_prompt = replace_with + " " + output_prompt + + # if num_instances > 1: + # print( + # f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. 
This may cause issues.") + + return output_prompt diff --git a/ai-toolkit/toolkit/reference_adapter.py b/ai-toolkit/toolkit/reference_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..25c3e2f571428dabaa4f2c22f6ff2405b04e1b65 --- /dev/null +++ b/ai-toolkit/toolkit/reference_adapter.py @@ -0,0 +1,400 @@ +import math + +import torch +import sys + +from PIL import Image +from torch.nn import Parameter +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from toolkit.basic import adain +from toolkit.saving import load_ip_adapter_model +from toolkit.train_tools import get_torch_dtype +from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict +from collections import OrderedDict +from toolkit.config_modules import AdapterConfig +from toolkit.prompt_utils import PromptEmbeds +import weakref + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + +from diffusers import ( + EulerDiscreteScheduler, + DDPMScheduler, +) + +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection +) +from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel + + +import torch.nn.functional as F +import torch.nn as nn + + +class ReferenceAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.ref_net = nn.Linear(hidden_size, hidden_size) + self.blend = nn.Parameter(torch.zeros(hidden_size)) + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self._memory = None + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = 
inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + if self.adapter_ref().is_active: + if self.adapter_ref().reference_mode == "write": + # write_mode + memory_ref = self.ref_net(hidden_states) + self._memory = memory_ref + elif self.adapter_ref().reference_mode == "read": + # read_mode + if self._memory is None: + print("Warning: no memory to read from") + else: + + saved_hidden_states = self._memory + try: + new_hidden_states = saved_hidden_states + blend = self.blend + # expand the blend buyt keep dim 0 the same (batch) + while blend.ndim < new_hidden_states.ndim: + blend = blend.unsqueeze(0) + # expand batch + blend = torch.cat([blend] * new_hidden_states.shape[0], dim=0) + hidden_states = blend * new_hidden_states + (1 - blend) * hidden_states + except Exception as e: + raise Exception(f"Error blending: {e}") + + return hidden_states + + +class ReferenceAdapter(torch.nn.Module): + + def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'): + super().__init__() + self.config 
= adapter_config + self.sd_ref: weakref.ref = weakref.ref(sd) + self.device = self.sd_ref().unet.device + self.reference_mode = "read" + self.current_scale = 1.0 + self.is_active = True + self._reference_images = None + self._reference_latents = None + self.has_memory = False + + self.noise_scheduler: Union[DDPMScheduler, EulerDiscreteScheduler] = None + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + for name in sd.unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor2_0() + else: + # layer_name = name.split(".processor")[0] + # weights = { + # "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], + # "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], + # } + + attn_procs[name] = ReferenceAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.config.num_tokens, + adapter=self + ) + # attn_procs[name].load_state_dict(weights) + sd.unet.set_attn_processor(attn_procs) + adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + sd.adapter = self + self.unet_ref: weakref.ref = weakref.ref(sd.unet) + self.adapter_modules = adapter_modules + # load the weights if we have some + if self.config.name_or_path: + loaded_state_dict = load_ip_adapter_model( + self.config.name_or_path, + 
device='cpu', + dtype=sd.torch_dtype + ) + self.load_state_dict(loaded_state_dict) + + self.set_scale(1.0) + self.attach() + self.to(self.device, self.sd_ref().torch_dtype) + + # if self.config.train_image_encoder: + # self.image_encoder.train() + # self.image_encoder.requires_grad_(True) + + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + # self.image_encoder.to(*args, **kwargs) + # self.image_proj_model.to(*args, **kwargs) + self.adapter_modules.to(*args, **kwargs) + return self + + def load_reference_adapter(self, state_dict: Union[OrderedDict, dict]): + reference_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + reference_layers.load_state_dict(state_dict["reference_adapter"]) + + # def load_state_dict(self, state_dict: Union[OrderedDict, dict]): + # self.load_ip_adapter(state_dict) + + def state_dict(self) -> OrderedDict: + state_dict = OrderedDict() + state_dict["reference_adapter"] = self.adapter_modules.state_dict() + return state_dict + + def get_scale(self): + return self.current_scale + + def set_reference_images(self, reference_images: Optional[torch.Tensor]): + self._reference_images = reference_images.clone().detach() + self._reference_latents = None + self.clear_memory() + + def set_blank_reference_images(self, batch_size): + self._reference_images = torch.zeros((batch_size, 3, 512, 512), device=self.device, dtype=self.sd_ref().torch_dtype) + self._reference_latents = torch.zeros((batch_size, 4, 64, 64), device=self.device, dtype=self.sd_ref().torch_dtype) + self.clear_memory() + + + def set_scale(self, scale): + self.current_scale = scale + for attn_processor in self.sd_ref().unet.attn_processors.values(): + if isinstance(attn_processor, ReferenceAttnProcessor2_0): + attn_processor.scale = scale + + + def attach(self): + unet = self.sd_ref().unet + self._original_unet_forward = unet.forward + unet.forward = lambda *args, **kwargs: self.unet_forward(*args, **kwargs) + if self.sd_ref().network is not None: + # 
def unet_forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
    """Replacement for ``unet.forward`` that ``attach`` installs on the UNet.

    When a reference is available, the adapter is active and no memory has
    been written yet, the adapter's own ``forward`` is run first (the
    reference pass). The call is then delegated to the original UNet forward.

    Args:
        sample: latent input forwarded to the UNet.
        timestep: timestep forwarded to the UNet.
        encoder_hidden_states: conditioning forwarded to the UNet.
        *args, **kwargs: forwarded unchanged to both passes.

    Returns:
        Whatever the original UNet forward returns.

    Raises:
        ValueError: if an attached network reports ``is_merged_in``.
    """
    has_reference = self._reference_images is not None or self._reference_latents is not None
    skip = (not has_reference) or (not self.is_active) or bool(self.has_memory)

    if not skip:
        network = self.sd_ref().network
        if network is not None:
            network.is_active = True
            if network.is_merged_in:
                raise ValueError("network is merged in, but we are not supposed to be merged in")
        # send it through our forward first
        self.forward(sample, timestep, encoder_hidden_states, *args, **kwargs)

    # NOTE(review): nesting reconstructed from a whitespace-collapsed source;
    # the deactivation is taken to run on every call - confirm against upstream.
    network = self.sd_ref().network
    if network is not None:
        network.is_active = False

    # Send it through the original unet forward. Extra positional arguments
    # must be unpacked; passing `args` as-is handed the UNet a single tuple.
    return self._original_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs)
Zero out the first half (unconditional) + if reference_latents.shape[0] * 2 == sample.shape[0]: + # we are doing cfg + # Unconditional goes first + reference_latents = torch.cat([torch.zeros_like(reference_latents), reference_latents], dim=0).detach() + + # resize it so reference_latents will fit inside sample in the center + width_scale = sample.shape[2] / reference_latents.shape[2] + height_scale = sample.shape[3] / reference_latents.shape[3] + scale = min(width_scale, height_scale) + # resize the reference latents + + mode = "bilinear" if scale > 1.0 else "bicubic" + + reference_latents = F.interpolate( + reference_latents, + size=(int(reference_latents.shape[2] * scale), int(reference_latents.shape[3] * scale)), + mode=mode, + align_corners=False + ) + + # add 0 padding if needed + width_pad = (sample.shape[2] - reference_latents.shape[2]) / 2 + height_pad = (sample.shape[3] - reference_latents.shape[3]) / 2 + reference_latents = F.pad( + reference_latents, + (math.floor(width_pad), math.floor(width_pad), math.ceil(height_pad), math.ceil(height_pad)), + mode="constant", + value=0 + ) + + # resize again just to make sure it is exact same size + reference_latents = F.interpolate( + reference_latents, + size=(sample.shape[2], sample.shape[3]), + mode="bicubic", + align_corners=False + ) + + # todo maybe add same noise to the sample? 
def enable_gradient_checkpointing(self):
    """Turn on gradient checkpointing for the image encoder, if there is one.

    The flag is set on ``self.image_encoder`` when that attribute exists and
    is not ``None``. Otherwise the call is a no-op instead of raising
    ``AttributeError``.
    """
    # NOTE(review): the image-encoder references visible in this class are
    # commented out, so the attribute is presumably absent - confirm.
    image_encoder = getattr(self, "image_encoder", None)
    if image_encoder is not None:
        image_encoder.gradient_checkpointing = True
def reshape_tensor(x, heads):
    """Split the last axis of ``x`` into attention heads.

    Args:
        x: tensor of shape ``(batch, length, width)``.
        heads: number of heads; ``width`` is divided evenly between them.

    Returns:
        Tensor of shape ``(batch, heads, length, width // heads)``.
    """
    batch, seq_len, _ = x.shape
    # (batch, length, width) -> (batch, length, heads, dim_per_head)
    # -> (batch, heads, length, dim_per_head)
    per_head = x.view(batch, seq_len, heads, -1).transpose(1, 2)
    # the head axis stays separate; it is not folded into the batch axis
    return per_head.reshape(batch, heads, seq_len, -1)
class Resampler(nn.Module):
    """Resampler that refines learned latent queries against input features.

    ``depth`` pairs of (``PerceiverAttention``, ``FeedForward``) are applied
    with residual connections to ``num_queries`` learned latents. The result
    is projected to ``output_dim`` and layer-normalised.

    Args:
        dim: width of the latents and of the projected input features.
        depth: number of attention / feed-forward pairs.
        dim_head: per-head width passed to ``PerceiverAttention``.
        heads: number of attention heads.
        num_queries: number of learned latent tokens.
        embedding_dim: width of the incoming features.
        output_dim: width of the returned tokens.
        ff_mult: hidden-size multiplier passed to ``FeedForward``.
        max_seq_len: size of the optional positional embedding table.
        apply_pos_emb: add a learned positional embedding to the input.
        num_latents_mean_pooled: number of extra latents derived from the
            mean-pooled input sequence (0 disables them).
    """

    def __init__(
            self,
            dim=1024,
            depth=8,
            dim_head=64,
            heads=16,
            num_queries=8,
            embedding_dim=768,
            output_dim=1024,
            ff_mult=4,
            max_seq_len: int = 257,  # CLIP tokens + CLS token
            apply_pos_emb: bool = False,
            num_latents_mean_pooled: int = 0,
    ):
        super().__init__()
        # Creation order is kept as-is: it determines the order in which
        # random initial weights are drawn.
        if apply_pos_emb:
            self.pos_emb = nn.Embedding(max_seq_len, embedding_dim)
        else:
            self.pos_emb = None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
        else:
            self.to_latents_from_mean_pooled_seq = None

        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """Return the normalised output tokens for input features ``x``."""
        if self.pos_emb is not None:
            positions = torch.arange(x.shape[1], device=x.device)
            x = x + self.pos_emb(positions)

        # one copy of the learned queries per batch element
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            # every position is kept, so this is a plain mean over the sequence
            keep_all = torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)
            pooled = masked_mean(x, dim=1, mask=keep_all)
            pooled_latents = self.to_latents_from_mean_pooled_seq(pooled)
            latents = torch.cat((pooled_latents, latents), dim=-2)

        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents

        return self.norm_out(self.proj_out(latents))
def get_sampler(
        sampler: str,
        kwargs: dict = None,
        arch: str = "sd"
):
    """Build a scheduler instance for the named sampler.

    Args:
        sampler: sampler name, e.g. ``"ddim"``, ``"euler_a"``,
            ``"k_dpmsolver++"``, ``"flowmatch"``. A ``k_`` prefix turns on
            ``use_karras_sigmas``.
        kwargs: optional config overrides applied on top of the base config.
        arch: model architecture. Selects the base config: ``sd_config`` for
            ``"sd"``, otherwise ``pixart_config``. For ``"flowmatch"`` the
            flow config of the architecture is used instead.

    Returns:
        The scheduler created through ``from_config``.

    Raises:
        ValueError: if ``sampler`` is not a supported name.
    """
    overrides = {}
    if kwargs is not None:
        overrides.update(kwargs)

    base_config = sd_config if arch == "sd" else pixart_config

    if sampler.startswith("k_"):
        overrides["use_karras_sigmas"] = True

    # samplers that need nothing beyond picking the class
    plain_schedulers = {
        "ddim": DDIMScheduler,
        "ddpm": DDPMScheduler,  # ddpm is not supported ?
        "pndm": PNDMScheduler,
        "lms": LMSDiscreteScheduler,
        "k_lms": LMSDiscreteScheduler,
        "euler": EulerDiscreteScheduler,
        "k_euler": EulerDiscreteScheduler,
        "euler_a": EulerAncestralDiscreteScheduler,
        "dpmsingle": DPMSolverSinglestepScheduler,
        "heun": HeunDiscreteScheduler,
        "dpm_2": KDPM2DiscreteScheduler,
        "dpm_2_a": KDPM2AncestralDiscreteScheduler,
        "lcm": LCMScheduler,
        "custom_lcm": CustomLCMScheduler,
        "mean_flow": MeanFlowScheduler,
    }
    multistep_names = ("dpmsolver", "dpmsolver++", "k_dpmsolver", "k_dpmsolver++")

    if sampler in plain_schedulers:
        scheduler_cls = plain_schedulers[sampler]
    elif sampler in multistep_names:
        scheduler_cls = DPMSolverMultistepScheduler
        overrides["algorithm_type"] = sampler.replace("k_", "")
    elif sampler == "flowmatch":
        scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
        flow_configs = {
            "sd": sd_flow_config,
            "flux": flux_config,
            "lumina2": lumina2_config,
        }
        if arch not in flow_configs:
            print(f"Unknown architecture {arch}, using default flux config")
        # use flux by default
        base_config = flow_configs.get(arch, flux_config)
    else:
        raise ValueError(f"Sampler {sampler} not supported")

    # work on a private copy so the module-level templates are never mutated
    config = copy.deepcopy(base_config)
    config.update(overrides)

    return scheduler_cls.from_config(config)
def calculate_shift(
        image_seq_len,
        base_seq_len: int = 256,
        max_seq_len: int = 4096,
        base_shift: float = 0.5,
        max_shift: float = 1.16,
):
    """Linearly interpolate the shift value ``mu`` for an image sequence length.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Lengths outside that range are extrapolated.

    Returns:
        The shift value for ``image_seq_len``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
    """Look up the sigma for each timestep in the current schedule.

    Args:
        timesteps: 1-D tensor of timesteps; each must match exactly one entry
            of ``self.timesteps``.
        n_dim: minimum rank of the result; trailing size-1 axes are appended
            until it is reached.
        dtype: dtype of the returned sigmas.
        device: device used for the lookup and for the result.

    Returns:
        Tensor of sigmas with one leading entry per timestep.
    """
    schedule = self.timesteps.to(device)
    wanted = timesteps.to(device)

    # position of every requested timestep inside the schedule
    positions = []
    for t in wanted:
        positions.append((schedule == t).nonzero().item())

    picked = self.sigmas.to(device=device, dtype=dtype)[positions].flatten()

    # pad with trailing singleton axes so the result broadcasts against samples
    missing_axes = n_dim - picked.dim()
    if missing_axes > 0:
        picked = picked.view(*picked.shape, *([1] * missing_axes))
    return picked
def set_train_timesteps(
        self,
        num_timesteps,
        device,
        timestep_type='linear',
        latents=None,
        patch_size=1
):
    """Create the training timestep schedule and store it on the scheduler.

    Args:
        num_timesteps: number of timesteps to generate.
        device: device the timesteps are created on.
        timestep_type: one of
            ``'linear'`` / ``'weighted'`` - evenly spaced from 1000 down to 1;
            ``'sigmoid'`` - random, sigmoid of normal samples, sorted descending;
            ``'flux_shift'`` / ``'lumina2_shift'`` / ``'shift'`` - sigma-shifted
            schedule built from the scheduler config (also sets ``self.sigmas``);
            ``'lognorm_blend'`` - 75% log-normal samples blended with 25%
            linear timesteps, sorted descending and cast to int.
        latents: only needed for the shift types when
            ``config.use_dynamic_shifting`` is set; its last two axes give the
            image size.
        patch_size: patch edge used to turn the latent size into a sequence
            length for dynamic shifting.

    Returns:
        The timesteps tensor (also stored in ``self.timesteps``).

    Raises:
        ValueError: for an unknown ``timestep_type``, or when dynamic shifting
            is enabled and ``latents`` is ``None``.
    """
    self.timestep_type = timestep_type
    if timestep_type == 'linear' or timestep_type == 'weighted':
        timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
        self.timesteps = timesteps
        return timesteps
    elif timestep_type == 'sigmoid':
        # distribute them closer to center. Inference distributes them as a bias toward first
        # Generate values from 0 to 1
        t = torch.sigmoid(torch.randn((num_timesteps,), device=device))

        # Scale and reverse the values to go from 1000 to 0
        timesteps = ((1 - t) * 1000)

        # Sort the timesteps in descending order
        timesteps, _ = torch.sort(timesteps, descending=True)

        self.timesteps = timesteps.to(device=device)

        return timesteps
    elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']:
        # matches inference dynamic shifting
        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max), self._sigma_to_t(
                self.sigma_min), num_timesteps
        )

        sigmas = timesteps / self.config.num_train_timesteps

        if self.config.use_dynamic_shifting:
            if latents is None:
                raise ValueError('latents is None')

            # for flux the caller passes a doubled patch size to simulate the latent reduction
            h = latents.shape[2]
            w = latents.shape[3]
            image_seq_len = h * w // (patch_size**2)

            mu = calculate_shift(
                image_seq_len,
                self.config.get("base_image_seq_len", 256),
                self.config.get("max_image_seq_len", 4096),
                self.config.get("base_shift", 0.5),
                self.config.get("max_shift", 1.16),
            )
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)

        if self.config.shift_terminal:
            sigmas = self.stretch_shift_to_terminal(sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(
                in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(
                in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(
                in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)

        sigmas = torch.from_numpy(sigmas).to(
            dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps

        if self.config.invert_sigmas:
            sigmas = 1.0 - sigmas
            timesteps = sigmas * self.config.num_train_timesteps
            sigmas = torch.cat(
                [sigmas, torch.ones(1, device=sigmas.device)])
        else:
            sigmas = torch.cat(
                [sigmas, torch.zeros(1, device=sigmas.device)])

        # (the original assigned self.timesteps twice with the same value)
        self.timesteps = timesteps.to(device=device)
        self.sigmas = sigmas
        return timesteps

    elif timestep_type == 'lognorm_blend':
        # distribute timesteps toward the center/early part and blend in linear ones
        alpha = 0.75

        lognormal = LogNormal(loc=0, scale=0.333)

        lognorm_count = int(num_timesteps * alpha)
        # Give the remainder to the linear part. Truncating both shares with
        # int() could return fewer than num_timesteps values (e.g. 7 + 2 for 10).
        linear_count = num_timesteps - lognorm_count

        parts = []
        if lognorm_count > 0:
            # Sample from the distribution
            t1 = lognormal.sample((lognorm_count,)).to(device)

            # Scale and reverse the values to go from 1000 to 0
            t1 = ((1 - t1 / t1.max()) * 1000)
            parts.append(t1)

        # add the linear share
        parts.append(torch.linspace(1000, 1, linear_count, device=device))
        timesteps = torch.cat(parts)

        # Sort the timesteps in descending order
        timesteps, _ = torch.sort(timesteps, descending=True)

        timesteps = timesteps.to(torch.int)
        self.timesteps = timesteps.to(device=device)
        return timesteps
    else:
        raise ValueError(f"Invalid timestep type: {timestep_type}")
def betas_for_alpha_bar(
        num_diffusion_timesteps,
        max_beta=0.999,
        alpha_transform_type="cosine",
):
    """Discretise an ``alpha_bar(t)`` schedule on ``t`` in ``[0, 1]`` into betas.

    ``alpha_bar`` is the cumulative product of ``(1 - beta)`` up to time ``t``.
    Each beta is ``1 - alpha_bar(t_next) / alpha_bar(t)``, capped at
    ``max_beta``.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound for every beta; values below 1
            prevent singularities.
        alpha_transform_type (`str`): `cosine` or `exp`.

    Returns:
        `torch.Tensor` of dtype float32 holding the betas.

    Raises:
        ValueError: for any other ``alpha_transform_type``.
    """
    if alpha_transform_type == "cosine":
        alpha_bar_fn = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2  # noqa: E731
    elif alpha_transform_type == "exp":
        alpha_bar_fn = lambda t: math.exp(t * -12.0)  # noqa: E731
    else:
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((i + 1) / steps) / alpha_bar_fn(i / steps), max_beta)
        for i in range(steps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class CustomLCMScheduler(SchedulerMixin, ConfigMixin): + """ + `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config + attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be + accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving + functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + original_inference_steps (`int`, *optional*, defaults to 50): + The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we + will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. 
+ clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + timestep_scaling (`float`, defaults to 10.0): + The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions + `c_skip` and `c_out`. 
@register_to_config
def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "scaled_linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        original_inference_steps: int = 50,
        clip_sample: bool = False,
        clip_sample_range: float = 1.0,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        timestep_scaling: float = 10.0,
        rescale_betas_zero_snr: bool = False,
):
    """Build the beta / alpha tables and reset the scheduler's mutable state.

    The betas come from ``trained_betas`` when given, otherwise from
    ``beta_schedule`` (``linear``, ``scaled_linear`` or ``squaredcos_cap_v2``).
    They are optionally rescaled to zero terminal SNR. The remaining
    arguments are not read here.

    Raises:
        NotImplementedError: for an unknown ``beta_schedule``.
    """
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = (
            torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        )
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

    # Rescale for zero SNR
    if rescale_betas_zero_snr:
        self.betas = rescale_zero_terminal_snr(self.betas)

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # At every step in ddim, we are looking into the previous alphas_cumprod
    # For the final step, there is no previous alphas_cumprod because we are already at 0
    # `set_alpha_to_one` decides whether we set this parameter simply to one or
    # whether we use the final alpha of the "non-previous" one.
    self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0
    # Honour the constructor argument; this was hard-coded to 50 (the
    # default), so a configured value was silently ignored.
    self.original_inference_steps = original_inference_steps

    # setable values
    self.num_inference_steps = None
    self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    # Keep in step with the beta table length; this was hard-coded to 1000
    # (the default) regardless of `num_train_timesteps`.
    self.train_timesteps = num_train_timesteps

    self._step_index = None
for image-to-image) + if len(index_candidates) > 1: + step_index = index_candidates[1] + else: + step_index = index_candidates[0] + + self._step_index = step_index.item() + + @property + def step_index(self): + return self._step_index + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    strength: float = 1.0,
    original_inference_steps: Optional[int] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        strength (`float`, defaults to 1.0):
            Fraction of the original schedule that is kept: only the first
            `int(original_steps * strength)` entries of the linearly-spaced schedule are used.
        original_inference_steps (`int`, *optional*):
            The original number of inference steps, which will be used to generate a linearly-spaced timestep
            schedule. `num_inference_steps` timesteps are then taken from this schedule, evenly spaced in terms
            of indices. If not set, this defaults to the `original_inference_steps` attribute, and then to
            `self.config.original_inference_steps`.

    Raises:
        ValueError: If `num_inference_steps` or the original step count exceeds
            `self.config.num_train_timesteps`, if `num_inference_steps` exceeds the original step count, or if
            `original_steps * strength` leaves fewer schedule entries than `num_inference_steps`.
    """
    if original_inference_steps is None:
        # Per-call override not given: fall back to the value stored on the scheduler instance.
        original_inference_steps = getattr(self, "original_inference_steps", None)

    if num_inference_steps > self.config.num_train_timesteps:
        raise ValueError(
            f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
            f" `self.config.num_train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )

    self.num_inference_steps = num_inference_steps
    original_steps = (
        original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
    )

    if original_steps > self.config.num_train_timesteps:
        raise ValueError(
            f"`original_steps`: {original_steps} cannot be larger than `self.config.num_train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )

    if num_inference_steps > original_steps:
        raise ValueError(
            f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
            f" {original_steps} because the final timestep schedule will be a subset of the"
            f" `original_inference_steps`-sized initial timestep schedule."
        )

    # LCM Timesteps Setting
    # The skipping step parameter k from the paper.
    k = self.config.num_train_timesteps // original_steps
    # LCM Training/Distillation Steps Schedule
    # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
    lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
    skipping_step = len(lcm_origin_timesteps) // num_inference_steps

    if skipping_step < 1:
        raise ValueError(
            f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than"
            f" `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to"
            f" a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher"
            f" than {float(num_inference_steps / original_steps)}."
        )

    # LCM Inference Steps Schedule
    lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
    # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
    inference_indices = np.linspace(0, len(lcm_origin_timesteps) - 1, num=num_inference_steps)
    inference_indices = np.floor(inference_indices).astype(np.int64)
    timesteps = lcm_origin_timesteps[inference_indices]

    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

    # A new schedule invalidates any step index from a previous run.
    self._step_index = None
+ Returns: + [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Compute the predicted original sample x_0 based on the model parameterization + if self.config.prediction_type == "epsilon": # noise-prediction + predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + elif self.config.prediction_type == "sample": # x-prediction + predicted_original_sample = model_output + elif self.config.prediction_type == "v_prediction": # v-prediction + predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `LCMScheduler`." + ) + + # 5. 
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """
    Mix `noise` into `original_samples` using the cumulative alpha products at `timesteps`.

    Returns `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`, with the
    per-timestep coefficients given trailing singleton dimensions so they broadcast over the samples.
    """
    # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
    target_device = original_samples.device
    cumprod = self.alphas_cumprod.to(device=target_device, dtype=original_samples.dtype)
    step_ids = timesteps.to(target_device)

    def _as_broadcastable(coeff):
        # Flatten, then append singleton dims until the rank matches the samples.
        coeff = coeff.flatten()
        for _ in range(len(original_samples.shape) - len(coeff.shape)):
            coeff = coeff.unsqueeze(-1)
        return coeff

    signal_scale = _as_broadcastable(cumprod[step_ids] ** 0.5)
    noise_scale = _as_broadcastable((1 - cumprod[step_ids]) ** 0.5)

    return signal_scale * original_samples + noise_scale * noise
class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
    """
    `FlowMatchEulerDiscreteScheduler` subclass whose `step` is a single subtraction of the model output
    and whose noising is a linear interpolation between the sample and the noise.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_noise_sigma = 1.0
        self.timestep_type = "linear"

        # Linear timesteps from 1000 down to 1 (1000 entries), kept on the CPU.
        with torch.no_grad():
            self.linear_timesteps = torch.linspace(1000, 1, 1000, device="cpu")

    def get_weights_for_timesteps(
        self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
    ) -> torch.Tensor:
        """
        Return loss weights for `timesteps`.

        Every timestep is looked up in `self.timesteps` (the lookup raises unless exactly one entry
        matches). For `timestep_type == "weighted"` the weights are read from `default_weighing_scheme`
        at those positions; for any other type the plain float `1.0` is returned. `v2` is unused.
        """
        # Get the indices of the timesteps
        step_indices = []
        for t in timesteps:
            step_indices.append((self.timesteps == t).nonzero().item())

        if timestep_type != "weighted":
            return 1.0

        # Get the weights for the timesteps
        return torch.tensor(
            [default_weighing_scheme[i] for i in step_indices],
            device=timesteps.device,
            dtype=timesteps.dtype,
        )

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Blend samples and noise linearly with the mixing factor `timesteps / 1000`."""
        noise_fraction = (timesteps / 1000).to(original_samples.device)
        return (1.0 - noise_fraction) * original_samples + noise_fraction * noise

    def scale_model_input(
        self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """Return `sample` unchanged; `timestep` is ignored."""
        return sample

    def set_train_timesteps(self, num_timesteps, device, **kwargs):
        """Store and return `num_timesteps` values linearly spaced from 1000 down to 1 on `device`."""
        self.timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
        return self.timesteps

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
        **kwargs: Optional[dict],
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Take one step: `sample - model_output`. `timestep` and extra keyword arguments are ignored.

        Returns a `FlowMatchEulerDiscreteSchedulerOutput` when `return_dict` is true, otherwise a
        one-element tuple holding the result.
        """
        # single euler step (Eq. 5 ⇒ x₀ = x₁ − uθ)
        denoised = sample - model_output
        if return_dict:
            return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=denoised)
        return (denoised,)
load_file(base_path, device) + # convert to the right dtype + for key in converted_state_dict: + converted_state_dict[key] = converted_state_dict[key].to(device, dtype=dtype) + + # process operators first + for ldm_key in ldm_diffusers_operator_map: + # if the key cat is in the ldm key, we need to process it + if 'cat' in ldm_diffusers_operator_map[ldm_key]: + cat_list = [] + for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']: + cat_list.append(diffusers_state_dict[diffusers_key].detach()) + converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype) + diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['cat']) + ldm_matched_keys.append(ldm_key) + if 'slice' in ldm_diffusers_operator_map[ldm_key]: + tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]] + slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]] + converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device, + dtype=dtype) + diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['slice']) + ldm_matched_keys.append(ldm_key) + + # process the rest of the keys + for ldm_key in ldm_diffusers_keymap: + # if the key is in the ldm key, we need to process it + if ldm_diffusers_keymap[ldm_key] in diffusers_state_dict: + tensor = diffusers_state_dict[ldm_diffusers_keymap[ldm_key]].detach().to(device, dtype=dtype) + # see if we need to reshape + if ldm_key in ldm_diffusers_shape_map: + tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0]) + converted_state_dict[ldm_key] = tensor + diffusers_matched_keys.append(ldm_diffusers_keymap[ldm_key]) + ldm_matched_keys.append(ldm_key) + + # see if any are missing from know mapping + mapped_diffusers_keys = list(ldm_diffusers_keymap.values()) + mapped_ldm_keys = list(ldm_diffusers_keymap.keys()) + + missing_diffusers_keys = [x for x in mapped_diffusers_keys if x not in diffusers_matched_keys] + 
missing_ldm_keys = [x for x in mapped_ldm_keys if x not in ldm_matched_keys] + + if len(missing_diffusers_keys) > 0: + print(f"WARNING!!!! Missing {len(missing_diffusers_keys)} diffusers keys") + print(missing_diffusers_keys) + if len(missing_ldm_keys) > 0: + print(f"WARNING!!!! Missing {len(missing_ldm_keys)} ldm keys") + print(missing_ldm_keys) + + return converted_state_dict + + +def get_ldm_state_dict_from_diffusers( + state_dict: 'OrderedDict', + sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega', 'sdxl_refiner'] = '2', + device='cpu', + dtype=get_torch_dtype('fp32'), +): + if sd_version == '1': + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1.json') + elif sd_version == '2': + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2.json') + elif sd_version == 'sdxl': + # load our base + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json') + elif sd_version == 'ssd': + # load our base + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json') + elif sd_version == 'vega': + # load our base + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega.json') + elif sd_version == 'sdxl_refiner': + # load our base + base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors') + mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json') + else: + raise ValueError(f"Invalid sd_version {sd_version}") + + # convert the state dict + return convert_state_dict_to_ldm_with_mapping( + state_dict, + mapping_path, + base_path, + device=device, + 
def save_lora_from_diffusers(
        lora_state_dict: 'OrderedDict',
        output_file: str,
        meta: 'OrderedDict',
        save_dtype=get_torch_dtype('fp16'),
        sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2'
):
    """
    Save a LoRA state dict to `output_file` as safetensors.

    Every tensor is detached, moved to the CPU and cast to `save_dtype`; `meta` is written as the
    file metadata. The parent directory of `output_file` is created when it does not exist.

    Raises:
        ValueError: if `sd_version` is not one of 'sdxl', 'ssd' or 'vega' (the only versions handled).
    """
    # only handle sdxl-family versions for now
    if sd_version not in ('sdxl', 'ssd', 'vega'):
        raise ValueError(f"Invalid sd_version {sd_version}")

    converted_state_dict = OrderedDict()
    for key, value in lora_state_dict.items():
        # todo verify if this works with ssd
        # Text encoder keys ('lora_te*') are shared between formats and keep their name.
        # todo: no key conversion is implemented for the remaining keys yet, so they are stored
        #  under their original name as well rather than being silently dropped from the file.
        # (The previous code called the non-existent `str.begins_with` and raised AttributeError
        #  on the first key.)
        converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)

    # make sure parent folder exists; a bare filename has no parent and os.makedirs('') would raise
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    save_file(converted_state_dict, output_file, metadata=meta)
def load_ip_adapter_model(
        path_to_file,
        device: Union[str] = 'cpu',
        dtype: torch.dtype = torch.float32,
        direct_load: bool = False
):
    """
    Load IP-adapter weights from `path_to_file`.

    Paths not ending in '.safetensors' are handed to `torch.load` and returned as loaded.
    Safetensors files are returned flat when `direct_load` is set; otherwise each key is split at
    its first '.' into a module name and a parameter name, and the result is a dict of per-module
    dicts whose tensors are moved to `device` and cast to `dtype`.
    """
    # check if it is safetensors or checkpoint
    if not path_to_file.endswith('.safetensors'):
        # NOTE(review): torch.load unpickles the file - only use with trusted checkpoints.
        return torch.load(path_to_file, map_location=device)

    flat_state = load_file(path_to_file, device)
    if direct_load:
        return flat_state

    grouped_state = OrderedDict()
    for flat_key, tensor in flat_state.items():
        module_name, _, param_name = flat_key.partition('.')
        module_state = grouped_state.setdefault(module_name, OrderedDict())
        module_state[param_name] = tensor.detach().to(device, dtype=dtype)
    return grouped_state
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
    """
    Derive a LoRA key map from a model key map.

    For every non-bias entry the trailing '.weight' is stripped from key and value, the model
    prefixes are replaced by their LoRA prefixes, remaining periods become underscores, and five
    entries are emitted: `lora_down.weight`, `lora_down.bias`, `lora_up.weight`, `lora_up.bias`
    and `alpha`.

    With dual text encoders (some key starts with 'conditioner.embedders.1') the value prefixes
    `te0` / `te1` are renamed to `lora_te1` / `lora_te2`, matching what is done to the keys.
    """
    lora_keymap = OrderedDict()

    # see if we have dual text encoders: a key that starts with conditioner.embedders.1
    has_dual_text_encoders = any(
        key.startswith('conditioner.embedders.1') for key in model_keymap
    )

    weight_suffix = '.weight'
    lora_suffixes = ('lora_down.weight', 'lora_down.bias', 'lora_up.weight', 'lora_up.bias', 'alpha')

    # map through the keys and values
    for key, value in model_keymap.items():
        # ignore bias weights
        if key.endswith('bias'):
            continue
        if key.endswith(weight_suffix):
            key = key[:-len(weight_suffix)]
        if value.endswith(weight_suffix):
            value = value[:-len(weight_suffix)]

        # unet for all
        key = key.replace('model.diffusion_model', 'lora_unet')
        if value.startswith('unet'):
            value = f"lora_{value}"

        # text encoder
        if has_dual_text_encoders:
            key = key.replace('conditioner.embedders.0', 'lora_te1')
            key = key.replace('conditioner.embedders.1', 'lora_te2')
            if value.startswith('te0') or value.startswith('te1'):
                value = f"lora_{value}"
                # str.replace returns a new string, so the result must be assigned (it used to be
                # discarded). te1 is shifted first so the renamed te0 is not shifted a second time.
                value = value.replace('lora_te1', 'lora_te2')
                value = value.replace('lora_te0', 'lora_te1')

        key = key.replace('cond_stage_model.transformer', 'lora_te')

        if value.startswith('te_'):
            value = f"lora_{value}"

        # replace periods with underscores
        key = key.replace('.', '_')
        value = value.replace('.', '_')

        # add all the weights
        for suffix in lora_suffixes:
            lora_keymap[f"{key}.{suffix}"] = f"{value}.{suffix}"

    return lora_keymap
# Template device state: every component in eval mode on the CPU.
# 'vae' carries no 'requires_grad' entry; the other components do.
empty_preset = {
    'vae': {'training': False, 'device': 'cpu'},
    'unet': {'training': False, 'requires_grad': False, 'device': 'cpu'},
    'text_encoder': {'training': False, 'requires_grad': False, 'device': 'cpu'},
    'adapter': {'training': False, 'requires_grad': False, 'device': 'cpu'},
    'refiner_unet': {'training': False, 'requires_grad': False, 'device': 'cpu'},
}


def get_train_sd_device_state_preset(
        device: Union[str, torch.device],
        train_unet: bool = False,
        train_text_encoder: bool = False,
        cached_latents: bool = False,
        train_lora: bool = False,
        train_adapter: bool = False,
        train_embedding: bool = False,
        train_decorator: bool = False,
        train_refiner: bool = False,
        unload_text_encoder: bool = False,
        require_grads: bool = True,
):
    """
    Build a device-state preset for a training run from a deep copy of `empty_preset`.

    The flags are applied in a fixed order, so later ones override earlier ones:
    unet, text encoder, embedding, refiner, lora, adapter, decorator, and finally
    `unload_text_encoder`, which always puts the text encoder back on the CPU in eval mode.
    The template itself is never modified.
    """
    preset = copy.deepcopy(empty_preset)
    vae = preset['vae']
    unet = preset['unet']
    text_encoder = preset['text_encoder']
    adapter = preset['adapter']
    refiner_unet = preset['refiner_unet']

    # the vae is only needed on the device when latents are not cached
    if not cached_latents:
        vae['device'] = device

    # unet and text encoder go to the device whether or not they are trained
    unet['device'] = device
    if train_unet:
        unet.update(training=True, requires_grad=require_grads)

    text_encoder['device'] = device
    if train_text_encoder:
        text_encoder.update(training=True, requires_grad=require_grads)

    if train_embedding:
        text_encoder.update(training=True, requires_grad=require_grads)
        unet['training'] = True

    if train_refiner:
        refiner_unet.update(training=True, requires_grad=require_grads, device=device)
        # if not training unet, move that to cpu
        if not train_unet:
            unet['device'] = 'cpu'

    if train_lora:
        # the text encoder's requires_grad is left as it is here
        unet['requires_grad'] = False
        if train_refiner:
            refiner_unet['requires_grad'] = False

    if train_adapter:
        adapter.update(requires_grad=require_grads, training=True, device=device)
        unet.update(training=True, requires_grad=False, device=device)
        text_encoder['device'] = device

    if train_decorator:
        text_encoder.update(training=False, requires_grad=False, device=device)
        unet.update(training=True, requires_grad=False, device=device)

    if unload_text_encoder:
        text_encoder.update(training=False, requires_grad=False, device='cpu')

    return preset
+import random +import shutil +import typing +from typing import Optional, Union, List, Literal, Iterator +import sys +import os +from collections import OrderedDict +import copy +import yaml +from PIL import Image +from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \ + ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from safetensors.torch import save_file, load_file +from torch import autocast +from torch.nn import Parameter +from torch.utils.checkpoint import checkpoint +from tqdm import tqdm +from torchvision.transforms import Resize, transforms + +from toolkit.assistant_lora import load_assistant_lora_from_path +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.ip_adapter import IPAdapter +from toolkit.util.vae import load_vae +from toolkit import train_tools +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch +from toolkit.metadata import get_meta_for_safetensors +from toolkit.models.decorator import Decorator +from toolkit.paths import KEYMAPS_ROOT +from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.sampler import get_sampler +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers +from toolkit.sd_device_states_presets import empty_preset +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +from einops import rearrange, repeat +import torch +from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ + StableDiffusionKDiffusionXLPipeline, 
StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline, \ + FluxAdvancedControlPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ + StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \ + StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ + StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \ + StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \ + FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Pipeline, \ + FluxControlPipeline, Lumina2Transformer2DModel +import diffusers +from diffusers import \ + AutoencoderKL, \ + UNet2DConditionModel +from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline +from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection + +from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT +from huggingface_hub import hf_hub_download +from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance + +from optimum.quanto import freeze, qfloat8, QTensor, qint4 +from toolkit.util.quantize import quantize, get_qtype +from toolkit.accelerator import get_accelerator, unwrap_model +from typing import TYPE_CHECKING +from toolkit.print import print_acc +from diffusers import FluxFillPipeline +from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel +from toolkit.basic import flush + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# tell it to shut up +diffusers.logging.set_verbosity(diffusers.logging.ERROR) + +SD_PREFIX_VAE = "vae" +SD_PREFIX_UNET = "unet" 
class BlankNetwork:
    """Minimal network stand-in: a few state flags, a context manager, and a no-op ``train``.

    NOTE(review): presumably used where a real LoRA network object would
    otherwise be expected — confirm against callers.
    """

    def __init__(self):
        self.multiplier = 1.0
        self.is_active = True
        self.is_merged_in = False
        self.can_merge_in = False

    def __enter__(self):
        # Entering the context marks the network active.
        self.is_active = True
        # Return self so ``with BlankNetwork() as net:`` binds the instance
        # instead of None; plain ``with net:`` callers are unaffected.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Leaving the context marks the network inactive. Returning None
        # (falsy) lets any exception raised inside the block propagate.
        self.is_active = False

    def train(self):
        # Intentionally a no-op.
        pass
'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] + self.vae: Union[None, 'AutoencoderKL'] + self.unet: Union[None, 'UNet2DConditionModel'] + self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] + self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] + self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler + + self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None + self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None + + # sdxl stuff + self.logit_scale = None + self.ckppt_info = None + self.is_loaded = False + + # to hold network if there is one + self.network = None + self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None + self.decorator: Union[Decorator, None] = None + self.arch: ModelArch = model_config.arch + # self.is_xl = model_config.is_xl + # self.is_v2 = model_config.is_v2 + # self.is_ssd = model_config.is_ssd + # self.is_v3 = model_config.is_v3 + # self.is_vega = model_config.is_vega + # self.is_pixart = model_config.is_pixart + # self.is_auraflow = model_config.is_auraflow + # self.is_flux = model_config.is_flux + # self.is_lumina2 = model_config.is_lumina2 + + self.use_text_encoder_1 = model_config.use_text_encoder_1 + self.use_text_encoder_2 = model_config.use_text_encoder_2 + + self.config_file = None + + self.is_flow_matching = False + if self.is_flux or self.is_v3 or self.is_auraflow or self.is_lumina2 or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler): + self.is_flow_matching = True + + self.quantize_device = self.device_torch + self.low_vram = self.model_config.low_vram + + # merge in and preview active with -1 weight + self.invert_assistant_lora = False + self._after_sample_img_hooks = [] + self._status_update_hooks = [] + # todo update this based on the model + self.is_transformer = False + + self.sample_prompts_cache = None + + 
self.is_multistage = False + # a list of multistage boundaries starting with train step 1000 to first idx + self.multistage_boundaries: List[float] = [0.0] + # a list of trainable multistage boundaries + self.trainable_multistage_boundaries: List[int] = [0] + + # set true for models that encode control image into text embeddings + self.encode_control_in_text_embeddings = False + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = False + # do not resize control images + self.use_raw_control_images = False + # defines if the model supports model paths. Only some will + self.supports_model_paths = False + + # use new lokr format (default false for old models for backwards compatibility) + self.use_old_lokr_format = True + + # when padding to make batch size work, which side padding to use, right or left + # some llms need left side padding, others need right side + self.te_padding_side = "right" + + # can be used on models to invalidate cache if things change. + self.latent_space_version = None + + # if a mask is passed, do the loss with the mask. May be set false for models that use a mask for other reasons. 
def get_bucket_divisibility(self):
    """Return the bucket divisibility for this model (method of ``StableDiffusion``).

    Without a loaded vae the answer is a fixed 16. Otherwise it is
    ``2 ** (len(vae.config['block_out_channels']) - 1)``, doubled once more
    for flux / sd3 archs, and then doubled unconditionally.
    """
    vae = self.vae
    if vae is None:
        return 16

    downsample_stages = len(vae.config['block_out_channels']) - 1
    factor = 2 ** downsample_stages

    # flux packs this again,
    if self.is_flux or self.is_v3:
        factor *= 2

    # todo remove this
    return factor * 2
+ from toolkit.civitai import get_model_path_from_url + model_path = get_model_path_from_url(self.model_config.name_or_path) + + load_args = {} + if self.noise_scheduler: + load_args['scheduler'] = self.noise_scheduler + + if self.model_config.vae_path is not None: + load_args['vae'] = load_vae(self.model_config.vae_path, dtype) + if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = StableDiffusionXLPipeline + # pipln = StableDiffusionKDiffusionXLPipeline + + # see if path exists + if not os.path.exists(model_path) or os.path.isdir(model_path): + # try to load with default diffusers + pipe = pipln.from_pretrained( + model_path, + dtype=dtype, + device=self.device_torch, + # variant="fp16", + use_safetensors=True, + **load_args + ) + else: + pipe = pipln.from_single_file( + model_path, + device=self.device_torch, + torch_dtype=self.torch_dtype, + ) + + if 'vae' in load_args and load_args['vae'] is not None: + pipe.vae = load_args['vae'] + flush() + + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + for text_encoder in text_encoders: + text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + text_encoder = text_encoders + + pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + + if self.model_config.experimental_xl: + print_acc("Experimental XL mode enabled") + print_acc("Loading and injecting alt weights") + # load the mismatched weight and force it in + raw_state_dict = load_file(model_path) + replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone() + del raw_state_dict + # get state dict for for 2nd text encoder + te1_state_dict = text_encoders[1].state_dict() + # replace weight with mismatched weight + te1_state_dict['text_projection.weight'] = 
replacement_weight.to(self.device_torch, dtype=dtype) + flush() + print_acc("Injecting alt weights") + elif self.model_config.is_v3: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = StableDiffusion3Pipeline + + print_acc("Loading SD3 model") + # assume it is the large model + base_model_path = "stabilityai/stable-diffusion-3.5-large" + print_acc("Loading transformer") + subfolder = 'transformer' + transformer_path = model_path + # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + else: + # is remote use whatever path we were given + base_model_path = model_path + + transformer = SD3Transformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + if not self.low_vram: + # for low v ram, we leave it on the cpu. 
Quantizes slower, but allows training on primary gpu + transformer.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.lora_path is not None: + raise ValueError("LoRA is not supported for SD3 models currently") + + if self.model_config.quantize: + quantization_type = get_qtype(self.model_config.qtype) + print_acc("Quantizing transformer") + quantize(transformer, weights=quantization_type) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + print_acc("Loading vae") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + print_acc("Loading t5") + tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype) + text_encoder_3 = T5EncoderModel.from_pretrained( + base_model_path, + subfolder="text_encoder_3", + torch_dtype=dtype + ) + + text_encoder_3.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize: + print_acc("Quantizing T5") + quantize(text_encoder_3, weights=get_qtype(self.model_config.qtype)) + freeze(text_encoder_3) + flush() + + + # see if path exists + if not os.path.exists(model_path) or os.path.isdir(model_path): + try: + # try to load with default diffusers + pipe = pipln.from_pretrained( + base_model_path, + dtype=dtype, + device=self.device_torch, + tokenizer_3=tokenizer_3, + text_encoder_3=text_encoder_3, + transformer=transformer, + # variant="fp16", + use_safetensors=True, + repo_type="model", + ignore_patterns=["*.md", "*..gitattributes"], + **load_args + ) + except Exception as e: + print_acc(f"Error loading from pretrained: {e}") + raise e + + else: + pipe = pipln.from_single_file( + model_path, + transformer=transformer, + device=self.device_torch, + torch_dtype=self.torch_dtype, + tokenizer_3=tokenizer_3, + text_encoder_3=text_encoder_3, + 
**load_args + ) + + flush() + + text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3] + # replace the to function with a no-op since it throws an error instead of a warning + # text_encoders[2].to = lambda *args, **kwargs: None + for text_encoder in text_encoders: + text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + text_encoder = text_encoders + + + elif self.model_config.is_pixart: + te_kwargs = {} + # handle quantization of TE + te_is_quantized = False + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS" + if self.model_config.is_pixart_sigma: + main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" + + main_model_path = model_path + + # load the TE in 8bit mode + text_encoder = T5EncoderModel.from_pretrained( + main_model_path, + subfolder="text_encoder", + torch_dtype=self.torch_dtype, + **te_kwargs + ) + + # load the transformer + subfolder = "transformer" + # check if it is just the unet + if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)): + subfolder = None + + if te_is_quantized: + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda *args, **kwargs: None + + text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) + + if self.model_config.is_pixart_sigma: + # load the transformer only from the save + transformer = Transformer2DModel.from_pretrained( + model_path if self.model_config.unet_path is None else self.model_config.unet_path, + torch_dtype=self.torch_dtype, + subfolder='transformer' + ) + pipe: 
PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ) + + else: + + # load the transformer only from the save + transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, + subfolder=subfolder) + pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ).to(self.device_torch) + + if self.model_config.unet_sample_size is not None: + pipe.transformer.config.sample_size = self.model_config.unet_sample_size + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + + flush() + # text_encoder = pipe.text_encoder + # text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + tokenizer = pipe.tokenizer + + pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + if self.noise_scheduler is None: + self.noise_scheduler = pipe.scheduler + + + elif self.model_config.is_auraflow: + te_kwargs = {} + # handle quantization of TE + te_is_quantized = False + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + main_model_path = model_path + + # load the TE in 8bit mode + text_encoder = UMT5EncoderModel.from_pretrained( + main_model_path, + subfolder="text_encoder", + torch_dtype=self.torch_dtype, + **te_kwargs + ) + + # load the transformer + subfolder = "transformer" + # check if it is just the unet + if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, 
subfolder)): + subfolder = None + + if te_is_quantized: + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda *args, **kwargs: None + + # load the transformer only from the save + transformer = AuraFlowTransformer2DModel.from_pretrained( + model_path if self.model_config.unet_path is None else self.model_config.unet_path, + torch_dtype=self.torch_dtype, + subfolder='transformer' + ) + pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ) + + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + + # patch auraflow so it can handle other aspect ratios + # patch_auraflow_pos_embed(pipe.transformer.pos_embed) + + flush() + # text_encoder = pipe.text_encoder + # text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + tokenizer = pipe.tokenizer + + elif self.model_config.is_flux: + self.print_and_status_update("Loading Flux model") + # base_model_path = "black-forest-labs/FLUX.1-schnell" + base_model_path = self.model_config.name_or_path_original + self.print_and_status_update("Loading transformer") + subfolder = 'transformer' + transformer_path = model_path + local_files_only = False + # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. 
+ te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = FluxTransformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + # low_cpu_mem_usage=False, + # device_map=None + ) + # hack in model gpu splitter + if self.model_config.split_model_over_gpus: + add_model_gpu_splitter_to_flux( + transformer, + other_module_param_count_scale=self.model_config.split_model_other_module_param_count_scale + ) + + if not self.low_vram: + # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu + transformer.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None: + raise ValueError("Cannot load both assistant lora and inference lora at the same time") + + if self.model_config.lora_path: + raise ValueError("Cannot load both assistant lora and lora at the same time") + + if not self.is_flux: + raise ValueError("Assistant/ inference lora is only supported for flux models currently") + + load_lora_path = self.model_config.inference_lora_path + if load_lora_path is None: + load_lora_path = self.model_config.assistant_lora_path + + if os.path.isdir(load_lora_path): + load_lora_path = os.path.join( + load_lora_path, "pytorch_lora_weights.safetensors" + ) + elif not os.path.exists(load_lora_path): + print_acc(f"Grabbing lora from the hub: {load_lora_path}") + new_lora_path = hf_hub_download( + load_lora_path, + filename="pytorch_lora_weights.safetensors" + ) + # replace the path + load_lora_path = new_lora_path + + if self.model_config.inference_lora_path is not None: + self.model_config.inference_lora_path = new_lora_path + if 
self.model_config.assistant_lora_path is not None: + self.model_config.assistant_lora_path = new_lora_path + + if self.model_config.assistant_lora_path is not None: + # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on + # quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps + # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half + # so we will merge in now and sample with -1 weight later + self.invert_assistant_lora = True + # trigger it to get merged in + self.model_config.lora_path = self.model_config.assistant_lora_path + + if self.model_config.lora_path is not None: + print_acc("Fusing in LoRA") + # need the pipe for peft + pipe: FluxPipeline = FluxPipeline( + scheduler=None, + text_encoder=None, + tokenizer=None, + text_encoder_2=None, + tokenizer_2=None, + vae=None, + transformer=transformer, + ) + if self.low_vram: + # we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts + # we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu + # we are going to separate it into the two transformer blocks one at a time + + lora_state_dict = load_file(self.model_config.lora_path) + single_transformer_lora = {} + single_block_key = "transformer.single_transformer_blocks." + double_transformer_lora = {} + double_block_key = "transformer.transformer_blocks." + for key, value in lora_state_dict.items(): + if single_block_key in key: + single_transformer_lora[key] = value + elif double_block_key in key: + double_transformer_lora[key] = value + else: + raise ValueError(f"Unknown lora key: {key}. 
Cannot load this lora in low vram mode") + + # double blocks + transformer.transformer_blocks = transformer.transformer_blocks.to( + self.quantize_device, dtype=dtype + ) + pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double") + pipe.fuse_lora() + pipe.unload_lora_weights() + transformer.transformer_blocks = transformer.transformer_blocks.to( + 'cpu', dtype=dtype + ) + + # single blocks + transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( + self.quantize_device, dtype=dtype + ) + pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single") + pipe.fuse_lora() + pipe.unload_lora_weights() + transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( + 'cpu', dtype=dtype + ) + + # cleanup + del single_transformer_lora + del double_transformer_lora + del lora_state_dict + flush() + + else: + # need the pipe to do this unfortunately for now + # we have to fuse in the weights before quantizing + pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") + pipe.fuse_lora() + # unfortunately, not an easier way with peft + pipe.unload_lora_weights() + flush() + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + self.print_and_status_update("Loading VAE") + if self.model_config.vae_path is not None: + vae = load_vae(self.model_config.vae_path, dtype) + else: + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + 
self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) + text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", + torch_dtype=dtype) + + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder_2, weights=get_qtype(self.model_config.qtype)) + freeze(text_encoder_2) + flush() + + self.print_and_status_update("Loading CLIP") + text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Making pipe") + Pipe = FluxPipeline + + pipe: Pipe = Pipe( + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + elif self.model_config.is_lumina2: + self.print_and_status_update("Loading Lumina2 model") + # base_model_path = "black-forest-labs/FLUX.1-schnell" + base_model_path = self.model_config.name_or_path_original + self.print_and_status_update("Loading transformer") + subfolder = 'transformer' + transformer_path = model_path 
+ if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = Lumina2Transformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + + if self.model_config.split_model_over_gpus: + raise ValueError("Splitting model over gpus is not supported for Lumina2 models") + + transformer.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + raise ValueError("Assistant LoRA is not supported for Lumina2 models currently") + + if self.model_config.lora_path is not None: + raise ValueError("Loading LoRA is not supported for Lumina2 models currently") + + flush() + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + self.print_and_status_update("Loading vae") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + if self.model_config.te_name_or_path is not None: + self.print_and_status_update("Loading TE") + tokenizer = AutoTokenizer.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype) + text_encoder = AutoModel.from_pretrained(self.model_config.te_name_or_path, 
torch_dtype=dtype) + else: + self.print_and_status_update("Loading Gemma2") + tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Gemma2") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype)) + freeze(text_encoder) + flush() + + self.print_and_status_update("Making pipe") + pipe: Lumina2Pipeline = Lumina2Pipeline( + scheduler=scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + ) + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = pipe.text_encoder + tokenizer = pipe.tokenizer + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder.to(self.device_torch) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + else: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = StableDiffusionPipeline + + if self.model_config.text_encoder_bits < 16: + # this is only supported for T5 models for now + te_kwargs = {} + # handle quantization of TE + te_is_quantized = False + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + text_encoder = T5EncoderModel.from_pretrained( + model_path, + subfolder="text_encoder", + torch_dtype=self.te_torch_dtype, + **te_kwargs + ) + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda 
*args, **kwargs: None + + load_args['text_encoder'] = text_encoder + + # see if path exists + if not os.path.exists(model_path) or os.path.isdir(model_path): + # try to load with default diffusers + pipe = pipln.from_pretrained( + model_path, + dtype=dtype, + device=self.device_torch, + load_safety_checker=False, + requires_safety_checker=False, + safety_checker=None, + # variant="fp16", + trust_remote_code=True, + **load_args + ) + else: + pipe = pipln.from_single_file( + model_path, + dtype=dtype, + device=self.device_torch, + load_safety_checker=False, + requires_safety_checker=False, + torch_dtype=self.torch_dtype, + safety_checker=None, + trust_remote_code=True, + **load_args + ) + flush() + + pipe.register_to_config(requires_safety_checker=False) + text_encoder = pipe.text_encoder + text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + tokenizer = pipe.tokenizer + + # scheduler doesn't get set sometimes, so we set it here + pipe.scheduler = self.noise_scheduler + + # add hacks to unet to help training + # pipe.unet = prepare_unet_for_training(pipe.unet) + + if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2: + # pixart and sd3 dont use a unet + self.unet = pipe.transformer + else: + self.unet: 'UNet2DConditionModel' = pipe.unet + self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + self.vae.eval() + self.vae.requires_grad_(False) + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + self.vae_scale_factor = VAE_SCALE_FACTOR + self.unet.to(self.device_torch, dtype=dtype) + self.unet.requires_grad_(False) + self.unet.eval() + + # load any loras we have + if self.model_config.lora_path is not None and not self.is_flux and not self.is_lumina2: + pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") + pipe.fuse_lora() + # unfortunately, not an easier way with peft + 
pipe.unload_lora_weights() + + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.pipeline = pipe + self.load_refiner() + self.is_loaded = True + + if self.model_config.assistant_lora_path is not None: + print_acc("Loading assistant lora") + self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( + self.model_config.assistant_lora_path, self) + + if self.invert_assistant_lora: + # invert and disable during training + self.assistant_lora.multiplier = -1.0 + self.assistant_lora.is_active = False + + if self.model_config.inference_lora_path is not None: + print_acc("Loading inference lora") + self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( + self.model_config.inference_lora_path, self) + # disable during training + self.assistant_lora.is_active = False + + if self.is_pixart and self.vae_scale_factor == 16: + # TODO make our own pipeline? + # we generate an image 2x larger, so we need to copy the sizes from larger ones down + # ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN + for key in ASPECT_RATIO_256_BIN.keys(): + ASPECT_RATIO_256_BIN[key] = [ASPECT_RATIO_256_BIN[key][0] * 2, ASPECT_RATIO_256_BIN[key][1] * 2] + for key in ASPECT_RATIO_512_BIN.keys(): + ASPECT_RATIO_512_BIN[key] = [ASPECT_RATIO_512_BIN[key][0] * 2, ASPECT_RATIO_512_BIN[key][1] * 2] + for key in ASPECT_RATIO_1024_BIN.keys(): + ASPECT_RATIO_1024_BIN[key] = [ASPECT_RATIO_1024_BIN[key][0] * 2, ASPECT_RATIO_1024_BIN[key][1] * 2] + for key in ASPECT_RATIO_2048_BIN.keys(): + ASPECT_RATIO_2048_BIN[key] = [ASPECT_RATIO_2048_BIN[key][0] * 2, ASPECT_RATIO_2048_BIN[key][1] * 2] + + def te_train(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.train() + else: + self.text_encoder.train() + + def te_eval(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.eval() + else: + self.text_encoder.eval() + + def load_refiner(self): + # for now, we are 
just going to rely on the TE from the base model + # which is TE2 for SDXL and TE for SD (no refiner currently) + # and completely ignore a TE that may or may not be packaged with the refiner + if self.model_config.refiner_name_or_path is not None: + refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') + # load the refiner model + dtype = get_torch_dtype(self.dtype) + model_path = self.model_config.refiner_name_or_path + if not os.path.exists(model_path) or os.path.isdir(model_path): + # TODO only load unet?? + refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( + model_path, + dtype=dtype, + device=self.device_torch, + # variant="fp16", + use_safetensors=True, + ).to(self.device_torch) + else: + refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( + model_path, + dtype=dtype, + device=self.device_torch, + torch_dtype=self.torch_dtype, + original_config_file=refiner_config_path, + ).to(self.device_torch) + + self.refiner_unet = refiner.unet + del refiner + flush() + + def _after_sample_image(self, img_num, total_imgs): + # process all hooks + for hook in self._after_sample_img_hooks: + hook(img_num, total_imgs) + + def add_after_sample_image_hook(self, func): + self._after_sample_img_hooks.append(func) + + def _status_update(self, status: str): + for hook in self._status_update_hooks: + hook(status) + + def print_and_status_update(self, status: str): + print_acc(status) + self._status_update(status) + + def add_status_update_hook(self, func): + self._status_update_hooks.append(func) + + @torch.no_grad() + def generate_images( + self, + image_configs: List[GenerateImageConfig], + sampler=None, + pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, + ): + network = unwrap_model(self.network) + merge_multiplier = 1.0 + flush() + # if using assistant, unfuse it + if self.model_config.assistant_lora_path is not None: + print_acc("Unloading assistant lora") + if self.invert_assistant_lora: + 
self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to(self.device_torch, self.torch_dtype) + else: + self.assistant_lora.is_active = False + + if self.model_config.inference_lora_path is not None: + print_acc("Loading inference lora") + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to(self.device_torch, self.torch_dtype) + + if network is not None: + network.eval() + # check if we have the same network weight for all samples. If we do, we can merge in th + # the network to drastically speed up inference + unique_network_weights = set([x.network_multiplier for x in image_configs]) + if len(unique_network_weights) == 1 and network.can_merge_in: + # make sure it is on device before merging. + self.unet.to(self.device_torch) + can_merge_in = True + merge_multiplier = unique_network_weights.pop() + network.merge_in(merge_weight=merge_multiplier) + else: + network = BlankNetwork() + + self.save_device_state() + self.set_device_state_preset('generate') + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + if pipeline is None: + noise_scheduler = self.noise_scheduler + if sampler is not None: + if sampler.startswith("sample_"): # sample_dpmpp_2m + # using ksampler + noise_scheduler = get_sampler( + 'lms', { + "prediction_type": self.prediction_type, + }) + else: + arch = 'sd' + if self.is_pixart: + arch = 'pixart' + if self.is_flux: + arch = 'flux' + if self.is_lumina2: + arch = 'lumina2' + noise_scheduler = get_sampler( + sampler, + { + "prediction_type": self.prediction_type, + }, + arch=arch + ) + + try: + noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype) + except: + pass + + if sampler.startswith("sample_") and self.is_xl: + # using kdiffusion + Pipe = StableDiffusionKDiffusionXLPipeline + elif self.is_xl: + Pipe = 
StableDiffusionXLPipeline + elif self.is_v3: + Pipe = StableDiffusion3Pipeline + else: + Pipe = StableDiffusionPipeline + + extra_args = {} + if self.adapter is not None: + if isinstance(self.adapter, T2IAdapter): + if self.is_xl: + Pipe = StableDiffusionXLAdapterPipeline + else: + Pipe = StableDiffusionAdapterPipeline + extra_args['adapter'] = self.adapter + elif isinstance(self.adapter, ControlNetModel): + if self.is_xl: + Pipe = StableDiffusionXLControlNetPipeline + else: + Pipe = StableDiffusionControlNetPipeline + extra_args['controlnet'] = self.adapter + elif isinstance(self.adapter, ReferenceAdapter): + # pass the noise scheduler to the adapter + self.adapter.noise_scheduler = noise_scheduler + else: + if self.is_xl: + extra_args['add_watermarker'] = False + + # TODO add clip skip + if self.is_xl: + pipeline = Pipe( + vae=self.vae, + unet=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=noise_scheduler, + **extra_args + ).to(self.device_torch) + pipeline.watermark = None + elif self.is_flux: + if self.model_config.use_flux_cfg: + pipeline = FluxWithCFGPipeline( + vae=self.vae, + transformer=unwrap_model(self.unet), + text_encoder=unwrap_model(self.text_encoder[0]), + text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=noise_scheduler, + **extra_args + ) + + else: + Pipe = FluxPipeline + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + # see if it is a control lora + if self.adapter.control_lora is not None: + Pipe = FluxAdvancedControlPipeline + extra_args['do_inpainting'] = self.adapter.config.has_inpainting_input + extra_args['num_controls'] = self.adapter.config.num_control_images + + pipeline = Pipe( + vae=self.vae, + transformer=unwrap_model(self.unet), + text_encoder=unwrap_model(self.text_encoder[0]), + 
text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=noise_scheduler, + **extra_args + ) + + pipeline.watermark = None + elif self.is_lumina2: + pipeline = Lumina2Pipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ) + elif self.is_v3: + pipeline = Pipe( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + text_encoder_3=self.text_encoder[2], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + tokenizer_3=self.tokenizer[2], + scheduler=noise_scheduler, + **extra_args + ) + elif self.is_pixart: + pipeline = PixArtSigmaPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ) + + elif self.is_auraflow: + pipeline = AuraFlowPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ) + + else: + pipeline = Pipe( + vae=self.vae, + unet=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + **extra_args + ) + flush() + # disable progress bar + pipeline.set_progress_bar_config(disable=True) + + if sampler.startswith("sample_"): + pipeline.set_scheduler(sampler) + + refiner_pipeline = None + if self.refiner_unet: + # build refiner pipeline + refiner_pipeline = StableDiffusionXLImg2ImgPipeline( + vae=pipeline.vae, + unet=self.refiner_unet, + text_encoder=None, + text_encoder_2=pipeline.text_encoder_2, + tokenizer=None, + tokenizer_2=pipeline.tokenizer_2, + scheduler=pipeline.scheduler, + add_watermarker=False, + requires_aesthetics_score=True, + ).to(self.device_torch) + # 
refiner_pipeline.register_to_config(requires_aesthetics_score=False) + refiner_pipeline.watermark = None + refiner_pipeline.set_progress_bar_config(disable=True) + flush() + + start_multiplier = 1.0 + if network is not None: + start_multiplier = network.multiplier + + # pipeline.to(self.device_torch) + + with network: + with torch.no_grad(): + if network is not None: + assert network.is_active + + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + gen_config = image_configs[i] + + extra = {} + validation_image = None + if self.adapter is not None and gen_config.adapter_image_path is not None: + validation_image = Image.open(gen_config.adapter_image_path) + # if the name doesnt have .inpainting. in it, make sure it is rgb + if ".inpaint." not in gen_config.adapter_image_path: + validation_image = validation_image.convert("RGB") + else: + # make sure it has an alpha + if validation_image.mode != "RGBA": + raise ValueError("Inpainting images must have an alpha channel") + if isinstance(self.adapter, T2IAdapter): + # not sure why this is double?? 
+ validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) + extra['image'] = validation_image + extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, ControlNetModel): + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['image'] = validation_image + extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None: + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['control_image'] = validation_image + extra['control_image_idx'] = gen_config.ctrl_idx + if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + if isinstance(self.adapter, CustomAdapter): + # todo allow loading multiple + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + self.adapter.num_images = 1 + if isinstance(self.adapter, ReferenceAdapter): + # need -1 to 1 + validation_image = transforms.ToTensor()(validation_image) + validation_image = validation_image * 2.0 - 1.0 + validation_image = validation_image.unsqueeze(0) + self.adapter.set_reference_images(validation_image) + + if network is not None: + network.multiplier = gen_config.network_multiplier + torch.manual_seed(gen_config.seed) + torch.cuda.manual_seed(gen_config.seed) + + generator = torch.manual_seed(gen_config.seed) + + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ + and gen_config.adapter_image_path is not None: + # run through the adapter to saturate the embeds + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) + self.adapter(conditional_clip_embeds) + + if self.adapter is not None and 
isinstance(self.adapter, CustomAdapter): + # handle condition the prompts + gen_config.prompt = self.adapter.condition_prompt( + gen_config.prompt, + is_unconditional=False, + ) + gen_config.prompt_2 = gen_config.prompt + gen_config.negative_prompt = self.adapter.condition_prompt( + gen_config.negative_prompt, + is_unconditional=True, + ) + gen_config.negative_prompt_2 = gen_config.negative_prompt + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: + self.adapter.trigger_pre_te( + tensors_0_1=validation_image, + is_training=False, + has_been_preprocessed=False, + quad_count=4 + ) + + if self.sample_prompts_cache is not None: + conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) + unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) + else: + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # allow any manipulations to take place to embeddings + gen_config.post_process_embeddings( + conditional_embeds, + unconditional_embeds, + ) + + if self.decorator is not None: + # apply the decorator to the embeddings + conditional_embeds.text_embeds = self.decorator(conditional_embeds.text_embeds) + unconditional_embeds.text_embeds = self.decorator(unconditional_embeds.text_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ + and 
gen_config.adapter_image_path is not None: + # apply the image projection + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, + True) + conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False) + unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=conditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_generating_samples=True, + ) + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=unconditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=True, + is_generating_samples=True, + ) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( + gen_config.extra_values) > 0: + extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, + dtype=self.torch_dtype) + # apply extra values to the embeddings + self.adapter.add_extra_values(extra_values, is_unconditional=False) + self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True) + pass # todo remove, for debugging + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # if we have a refiner loaded, set the denoising end at the refiner start + extra['denoising_end'] = gen_config.refiner_start_at + extra['output_type'] = 'latent' + if not self.is_xl: + raise ValueError("Refiner is only supported for XL models") + + conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype) + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype) + + if 
self.is_xl: + # fix guidance rescale for sdxl + # was trained on 0.7 (I believe) + + grs = gen_config.guidance_rescale + # if grs is None or grs < 0.00001: + # grs = 0.7 + # grs = 0.0 + + if sampler.startswith("sample_"): + extra['use_karras_sigmas'] = True + extra = { + **extra, + **gen_config.extra_kwargs, + } + + img = pipeline( + # prompt=gen_config.prompt, + # prompt_2=gen_config.prompt_2, + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + # negative_prompt=gen_config.negative_prompt, + # negative_prompt_2=gen_config.negative_prompt_2, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + guidance_rescale=grs, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + elif self.is_v3: + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + elif self.is_flux: + if self.model_config.use_flux_cfg: + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + 
**extra + ).images[0] + else: + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + latents = callback_kwargs["latents"] + if latents.dtype != self.unet.dtype: + latents = latents.to(self.unet.dtype) + return {"latents": latents} + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + # negative_prompt_embeds=unconditional_embeds.text_embeds, + # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + callback_on_step_end=callback_on_step_end, + **extra + ).images[0] + elif self.is_lumina2: + pipeline: Lumina2Pipeline = pipeline + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64), + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + elif self.is_pixart: + # needs attention masks for some reason + img = pipeline( + prompt=None, + prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt=None, + # 
negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + elif self.is_auraflow: + pipeline: AuraFlowPipeline = pipeline + + img = pipeline( + prompt=None, + prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt=None, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + else: + img = pipeline( + # prompt=gen_config.prompt, + prompt_embeds=conditional_embeds.text_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # slide off just the last 1280 on the last dim as refiner does not use first text encoder + # todo, should we just use the Text encoder for the refiner? 
Fine tuned versions will differ + refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:] + refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:] + # run through refiner + img = refiner_pipeline( + # prompt=gen_config.prompt, + # prompt_2=gen_config.prompt_2, + + # slice these as it does not use both text encoders + # height=gen_config.height, + # width=gen_config.width, + prompt_embeds=refiner_text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=refiner_unconditional_text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + guidance_rescale=grs, + denoising_start=gen_config.refiner_start_at, + denoising_end=gen_config.num_inference_steps, + image=img.unsqueeze(0), + generator=generator, + ).images[0] + + gen_config.save_image(img, i) + gen_config.log_image(img, i) + self._after_sample_image(i, len(image_configs)) + flush() + + if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() + + # clear pipeline and cache to reduce vram usage + del pipeline + if refiner_pipeline is not None: + del refiner_pipeline + torch.cuda.empty_cache() + + # restore training state + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + self.restore_device_state() + if network is not None: + network.train() + network.multiplier = start_multiplier + + self.unet.to(self.device_torch, dtype=self.torch_dtype) + if network.is_merged_in: + network.merge_out(merge_multiplier) + # self.tokenizer.to(original_device_dict['tokenizer']) + + # refuse loras + if self.model_config.assistant_lora_path is not None: + print_acc("Loading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', 
def get_latent_noise(
        self,
        height=None,
        width=None,
        pixel_height=None,
        pixel_width=None,
        batch_size=1,
        noise_offset=0.0,
        num_channels=None,
):
    """Create a random latent-space noise tensor.

    The spatial size is given either directly in latent units (``height`` /
    ``width``) or in pixels (``pixel_height`` / ``pixel_width``), in which case
    it is floor-divided by the VAE scale factor derived from the number of
    entries in ``self.vae.config['block_out_channels']``.

    Args:
        height: latent height; takes precedence over ``pixel_height``.
        width: latent width; takes precedence over ``pixel_width``.
        pixel_height: image height in pixels, used when ``height`` is None.
        pixel_width: image width in pixels, used when ``width`` is None.
        batch_size: size of the leading dimension of the noise.
        noise_offset: forwarded to ``apply_noise_offset``.
        num_channels: channel count; when None it is read from the unwrapped
            unet config (``in_channels``), divided by 4 for flux models.

    Returns:
        Tensor of shape ``(batch_size, num_channels, height, width)`` created
        on ``self.unet.device`` and passed through ``apply_noise_offset``.

    Raises:
        ValueError: if neither form of height, or neither form of width, is given.
    """
    VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
    if height is None and pixel_height is None:
        raise ValueError("height or pixel_height must be specified")
    if width is None and pixel_width is None:
        raise ValueError("width or pixel_width must be specified")
    if height is None:
        height = pixel_height // VAE_SCALE_FACTOR
    if width is None:
        width = pixel_width // VAE_SCALE_FACTOR

    if num_channels is None:
        num_channels = self.unet_unwrapped.config['in_channels']
        if self.is_flux:
            # it gets packed, unpack it
            num_channels = num_channels // 4

    noise = torch.randn(
        (batch_size, num_channels, height, width),
        device=self.unet.device,
    )
    return apply_noise_offset(noise, noise_offset)


def get_latent_noise_from_latents(
        self,
        latents: torch.Tensor,
        noise_offset=0.0
):
    """Return noise shaped like ``latents`` (``torch.randn_like``), passed through ``apply_noise_offset``."""
    noise = torch.randn_like(latents)
    return apply_noise_offset(noise, noise_offset)


def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
    """Build the additional time ids for a batch of latents on XL models.

    Args:
        latents: 4-dim tensor unpacked as ``(batch, channels, h, w)``.
        requires_aesthetic_score: when True (refiner), the ids are
            ``original_size + crop_top_left + (aesthetic_score,)`` with a fixed
            score of 6.0; otherwise ``original_size + crop_top_left + target_size``.

    Returns:
        A ``(batch, n_ids)`` tensor on the latents' device and dtype, or
        ``None`` when the model is not XL.
    """
    VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
    if not self.is_xl:
        return None

    bs, ch, h, w = list(latents.shape)

    # just do it without any cropping nonsense: original and target size are
    # both the full pixel size implied by the latents
    pixel_size = (h * VAE_SCALE_FACTOR, w * VAE_SCALE_FACTOR)
    crops_coords_top_left = (0, 0)
    if requires_aesthetic_score:
        # refiner
        # https://huggingface.co/papers/2307.01952
        aesthetic_score = 6.0  # simulate one
        id_values = list(pixel_size + crops_coords_top_left + (aesthetic_score,))
    else:
        id_values = list(pixel_size + crops_coords_top_left + pixel_size)

    add_time_ids = torch.tensor([id_values]).to(latents.device, dtype=latents.dtype)
    # repeat (rather than torch.cat over a Python list) also works for an
    # empty batch, where torch.cat([]) would raise
    return add_time_ids.repeat(bs, 1)


def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
        **kwargs,
) -> torch.FloatTensor:
    """Add noise to each sample individually through ``self.noise_scheduler.add_noise``.

    The batch is split along dim 0 and the scheduler is called once per sample
    with that sample's noise and timestep; the results are concatenated back
    along dim 0.

    Args:
        original_samples: batch of clean samples.
        noise: batch of noise, one entry per sample.
        timesteps: one timestep per sample, or a single timestep (a 1-element
            or 0-dim tensor) that is applied to every sample.
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        The noised samples, concatenated along dim 0.
    """
    # A 0-dim timestep has no shape[0]; promote it to one element so the
    # single-timestep broadcast below applies.
    if timesteps.dim() == 0:
        timesteps = timesteps.unsqueeze(0)

    batch_size = original_samples.shape[0]
    sample_chunks = torch.chunk(original_samples, batch_size, dim=0)
    noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    timestep_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    # a single timestep is shared by every sample in the batch
    if len(timestep_chunks) == 1 and len(timestep_chunks) != len(sample_chunks):
        timestep_chunks = [timestep_chunks[0]] * len(sample_chunks)

    noisy_chunks = [
        self.noise_scheduler.add_noise(sample_chunks[idx], noise_chunks[idx], timestep_chunks[idx])
        for idx in range(batch_size)
    ]
    return torch.cat(noisy_chunks, dim=0)
ValueError("Either text_embeddings or conditional_embeddings must be specified") + if text_embeddings is None and unconditional_embeddings is not None: + text_embeddings = concat_prompt_embeds([ + unconditional_embeddings, # negative embedding + conditional_embeddings, # positive embedding + ]) + elif text_embeddings is None and conditional_embeddings is not None: + # not doing cfg + text_embeddings = conditional_embeddings + + # CFG is comparing neg and positive, if we have concatenated embeddings + # then we are doing it, otherwise we are not and takes half the time. + do_classifier_free_guidance = True + + # check if batch size of embeddings matches batch size of latents + if latents.shape[0] == text_embeddings.text_embeds.shape[0]: + do_classifier_free_guidance = False + elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]: + raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings") + latents = latents.to(self.device_torch) + text_embeddings = text_embeddings.to(self.device_torch) + timestep = timestep.to(self.device_torch) + + # if timestep is zero dim, unsqueeze it + if len(timestep.shape) == 0: + timestep = timestep.unsqueeze(0) + + # if we only have 1 timestep, we can just use the same timestep for all + if timestep.shape[0] == 1 and latents.shape[0] > 1: + # check if it is rank 1 or 2 + if len(timestep.shape) == 1: + timestep = timestep.repeat(latents.shape[0]) + else: + timestep = timestep.repeat(latents.shape[0], 0) + + # handle t2i adapters + if 'down_intrablock_additional_residuals' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) + + # handle controlnet + if 'down_block_additional_residuals' in kwargs 
and 'mid_block_additional_residual' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_block_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) + for idx, item in enumerate(kwargs['mid_block_additional_residual']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0) + + def scale_model_input(model_input, timestep_tensor): + if is_input_scaled: + return model_input + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + # unsqueeze if timestep is zero dim + for idx in range(model_input.shape[0]): + # if scheduler has step_index + if hasattr(self.noise_scheduler, '_step_index'): + self.noise_scheduler._step_index = None + out_chunks.append( + self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx]) + ) + return torch.cat(out_chunks, dim=0) + + if self.is_xl: + with torch.no_grad(): + # 16, 6 for bs of 4 + if add_time_ids is None: + add_time_ids = self.get_time_ids_from_latents(latents) + + if do_classifier_free_guidance: + # todo check this with larget batches + add_time_ids = torch.cat([add_time_ids] * 2) + + if do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + timestep = torch.cat([timestep] * 2) + else: + latent_model_input = latents + + latent_model_input = scale_model_input(latent_model_input, timestep) + + added_cond_kwargs = { + # todo can we zero here the second text encoder? or match a blank string? 
+ "text_embeds": text_embeddings.pooled_embeds, + "time_ids": add_time_ids, + } + + if self.model_config.refiner_name_or_path is not None: + # we have the refiner on the second half of everything. Do Both + if do_classifier_free_guidance: + raise ValueError("Refiner is not supported with classifier free guidance") + + if self.unet.training: + input_chunks = torch.chunk(latent_model_input, 2, dim=0) + timestep_chunks = torch.chunk(timestep, 2, dim=0) + added_cond_kwargs_chunked = { + "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0), + "time_ids": torch.chunk(add_time_ids, 2, dim=0), + } + text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0) + + # predict the noise residual + base_pred = self.unet( + input_chunks[0], + timestep_chunks[0], + encoder_hidden_states=text_embeds_chunks[0], + added_cond_kwargs={ + "text_embeds": added_cond_kwargs_chunked['text_embeds'][0], + "time_ids": added_cond_kwargs_chunked['time_ids'][0], + }, + **kwargs, + ).sample + + refiner_pred = self.refiner_unet( + input_chunks[1], + timestep_chunks[1], + encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], + # just use the first second text encoder + added_cond_kwargs={ + "text_embeds": added_cond_kwargs_chunked['text_embeds'][1], + # "time_ids": added_cond_kwargs_chunked['time_ids'][1], + "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True), + }, + **kwargs, + ).sample + + noise_pred = torch.cat([base_pred, refiner_pred], dim=0) + else: + noise_pred = self.refiner_unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:], + # just use the first second text encoder + added_cond_kwargs={ + "text_embeds": text_embeddings.pooled_embeds, + "time_ids": self.get_time_ids_from_latents(latent_model_input, + requires_aesthetic_score=True), + }, + **kwargs, + ).sample + + else: + + # predict the noise residual + noise_pred = self.unet( + 
latent_model_input.to(self.device_torch, self.torch_dtype), + timestep, + encoder_hidden_states=text_embeddings.text_embeds, + added_cond_kwargs=added_cond_kwargs, + **kwargs, + ).sample + + conditional_pred = noise_pred + + if do_classifier_free_guidance: + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + conditional_pred = noise_pred_text + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + else: + with torch.no_grad(): + if do_classifier_free_guidance: + # if we are doing classifier free guidance, need to double up + latent_model_input = torch.cat([latents] * 2, dim=0) + timestep = torch.cat([timestep] * 2) + else: + latent_model_input = latents + + latent_model_input = scale_model_input(latent_model_input, timestep) + + # check if we need to concat timesteps + if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: + ts_bs = timestep.shape[0] + if ts_bs != latent_model_input.shape[0]: + if ts_bs == 1: + timestep = torch.cat([timestep] * latent_model_input.shape[0]) + elif ts_bs * 2 == latent_model_input.shape[0]: + timestep = torch.cat([timestep] * 2, dim=0) + else: + raise ValueError( + f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") + + # predict the noise residual + if self.is_pixart: + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + batch_size, ch, h, w = list(latents.shape) + + height = h * VAE_SCALE_FACTOR + width = w * VAE_SCALE_FACTOR + + if self.pipeline.transformer.config.sample_size == 256: + 
aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.pipeline.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.pipeline.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.pipeline.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}") + orig_height, orig_width = height, width + height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, + ratios=aspect_ratio_bin) + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.unet_unwrapped.config.sample_size == 128 or ( + self.vae_scale_factor == 16 and self.unet_unwrapped.config.sample_size == 64): + resolution = torch.tensor([height, width]).repeat(batch_size, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) + resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) + aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + encoder_hidden_states=text_embeddings.text_embeds, + encoder_attention_mask=text_embeddings.attention_mask, + timestep=timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + **kwargs + )[0] + + # learned sigma + if self.unet_unwrapped.config.out_channels // 2 == self.unet_unwrapped.config.in_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + else: + if self.unet.device != self.device_torch: + try: + self.unet.to(self.device_torch) + except Exception as e: + 
pass + if self.unet.dtype != self.torch_dtype: + self.unet = self.unet.to(dtype=self.torch_dtype) + if self.is_flux: + with torch.no_grad(): + + bs, c, h, w = latent_model_input.shape + latent_model_input_packed = rearrange( + latent_model_input, + "b c (h ph) (w pw) -> b (h w) (c ph pw)", + ph=2, + pw=2 + ) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch) + + txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + # # handle guidance + if self.unet_unwrapped.config.guidance_embeds: + if isinstance(guidance_embedding_scale, list): + guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch) + else: + guidance = torch.tensor([guidance_embedding_scale], device=self.device_torch) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if bypass_guidance_embedding: + bypass_flux_guidance(self.unet) + + cast_dtype = self.unet.dtype + # changes from orig implementation + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype): + noise_pred = self.unet( + hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64] + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + # todo make sure this doesnt change + timestep=timestep / 1000, # timestep is 1000 scale + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), + # [1, 512, 4096] + pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768] + txt_ids=txt_ids, # [1, 512, 3] + img_ids=img_ids, # [1, 4096, 3] + 
guidance=guidance, + return_dict=False, + **kwargs, + )[0] + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + noise_pred = rearrange( + noise_pred, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=latent_model_input.shape[2] // 2, + w=latent_model_input.shape[3] // 2, + ph=2, + pw=2, + # c=latent_model_input.shape[1], + c=self.vae.config.latent_channels + ) + + if bypass_guidance_embedding: + restore_flux_guidance(self.unet) + elif self.is_lumina2: + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + t = 1 - timestep / self.noise_scheduler.config.num_train_timesteps + with self.accelerator.autocast(): + noise_pred = self.unet( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=t, + encoder_attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64), + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + **kwargs, + ).sample + + # lumina2 does this before stepping. Should we do it here? 
+ noise_pred = -noise_pred + elif self.is_v3: + noise_pred = self.unet( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), + **kwargs, + ).sample + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + elif self.is_auraflow: + # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0]) + t = t.to(self.device_torch, self.torch_dtype) + + noise_pred = self.unet( + latent_model_input, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + timestep=t, + return_dict=False, + )[0] + else: + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + **kwargs, + ).sample + + conditional_pred = noise_pred + + if do_classifier_free_guidance: + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0) + conditional_pred = noise_pred_text + if detach_unconditional: + noise_pred_uncond = noise_pred_uncond.detach() + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if rescale_cfg is not None and rescale_cfg != guidance_scale: + with torch.no_grad(): + # do cfg at the target rescale so we can match it + target_pred_mean_std = noise_pred_uncond + rescale_cfg * ( + noise_pred_text - noise_pred_uncond + ) + target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach() + target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach() + + pred_mean = noise_pred.mean([1, 2, 3], 
keepdim=True).detach() + pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach() + + # match the mean and std + noise_pred = (noise_pred - pred_mean) / pred_std + noise_pred = (noise_pred * target_std) + target_mean + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if return_conditional_pred: + return noise_pred, conditional_pred + return noise_pred + + def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None): + if noise_scheduler is None: + noise_scheduler = self.noise_scheduler + # // sometimes they are on the wrong device, no idea why + if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler): + try: + noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch) + noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch) + noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch) + except Exception as e: + pass + + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0) + timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + if len(timestep_chunks) == 1 and len(mi_chunks) > 1: + # expand timestep to match + timestep_chunks = timestep_chunks * len(mi_chunks) + + for idx in range(model_input.shape[0]): + # Reset it so it is unique for the + if hasattr(noise_scheduler, '_step_index'): + noise_scheduler._step_index = None + if hasattr(noise_scheduler, 'is_scale_input_called'): + noise_scheduler.is_scale_input_called = True + out_chunks.append( + noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], 
latent_chunks[idx], return_dict=False)[ + 0] + ) + return torch.cat(out_chunks, dim=0) + + # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 + def diffuse_some_steps( + self, + latents: torch.FloatTensor, + text_embeddings: PromptEmbeds, + total_timesteps: int = 1000, + start_timesteps=0, + guidance_scale=1, + add_time_ids=None, + bleed_ratio: float = 0.5, + bleed_latents: torch.FloatTensor = None, + is_input_scaled=False, + return_first_prediction=False, + bypass_guidance_embedding=False, + **kwargs, + ): + timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps] + + first_prediction = None + + for timestep in tqdm(timesteps_to_run, leave=False): + timestep = timestep.unsqueeze_(0) + noise_pred, conditional_pred = self.predict_noise( + latents, + text_embeddings, + timestep, + guidance_scale=guidance_scale, + add_time_ids=add_time_ids, + is_input_scaled=is_input_scaled, + return_conditional_pred=True, + bypass_guidance_embedding=bypass_guidance_embedding, + **kwargs, + ) + # some schedulers need to run separately, so do that. 
(euler for example) + + if return_first_prediction and first_prediction is None: + first_prediction = conditional_pred + + latents = self.step_scheduler(noise_pred, latents, timestep) + + # if not last step, and bleeding, bleed in some latents + if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]: + latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio) + + # only skip first scaling + is_input_scaled = False + + # return latents_steps + if return_first_prediction: + return latents, first_prediction + return latents + + def encode_prompt( + self, + prompt, + prompt2=None, + num_images_per_prompt=1, + force_all=False, + long_prompts=False, + max_length=None, + dropout_prob=0.0, + control_images=None, + ) -> PromptEmbeds: + # sd1.5 embeddings are (bs, 77, 768) + prompt = prompt + # if it is not a list, make it one + if not isinstance(prompt, list): + prompt = [prompt] + + if prompt2 is not None and not isinstance(prompt2, list): + prompt2 = [prompt2] + if self.is_xl: + # todo make this a config + # 50% chance to use an encoder anyway even if it is disabled + # allows the other TE to compensate for the disabled one + # use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5 + # use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5 + use_encoder_1 = True + use_encoder_2 = True + + return PromptEmbeds( + train_tools.encode_prompts_xl( + self.tokenizer, + self.text_encoder, + prompt, + prompt2, + num_images_per_prompt=num_images_per_prompt, + use_text_encoder_1=use_encoder_1, + use_text_encoder_2=use_encoder_2, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob, + ) + ) + if self.is_v3: + return PromptEmbeds( + train_tools.encode_prompts_sd3( + self.tokenizer, + self.text_encoder, + prompt, + num_images_per_prompt=num_images_per_prompt, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob, + pipeline=self.pipeline, + ) + ) 
+ elif self.is_pixart: + embeds, attention_mask = train_tools.encode_prompts_pixart( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=300 if self.model_config.is_pixart_sigma else 120, + dropout_prob=dropout_prob + ) + return PromptEmbeds( + embeds, + attention_mask=attention_mask, + ) + elif self.is_auraflow: + embeds, attention_mask = train_tools.encode_prompts_auraflow( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=256, + dropout_prob=dropout_prob + ) + return PromptEmbeds( + embeds, + attention_mask=attention_mask, # not used + ) + elif self.is_flux: + prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux( + self.tokenizer, # list + self.text_encoder, # list + prompt, + truncate=not long_prompts, + max_length=512, + dropout_prob=dropout_prob, + attn_mask=self.model_config.attn_masking + ) + pe = PromptEmbeds( + prompt_embeds + ) + pe.pooled_embeds = pooled_prompt_embeds + return pe + + elif self.is_lumina2: + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.pipeline.encode_prompt( + prompt, + do_classifier_free_guidance=False, + num_images_per_prompt=1, + device=self.device_torch, + max_sequence_length=256, # should it be 512? + ) + return PromptEmbeds( + prompt_embeds, + attention_mask=prompt_attention_mask, + ) + + elif isinstance(self.text_encoder, T5EncoderModel): + embeds, attention_mask = train_tools.encode_prompts_pixart( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=256, + dropout_prob=dropout_prob + ) + + # just mask the attention mask + prompt_attention_mask = attention_mask.unsqueeze(-1).expand(embeds.shape) + embeds = embeds * prompt_attention_mask.to(dtype=embeds.dtype, device=embeds.device) + return PromptEmbeds( + embeds, + + # do we want attn mask here? 
+ # attention_mask=attention_mask, + ) + else: + return PromptEmbeds( + train_tools.encode_prompts( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob + ) + ) + + @torch.no_grad() + def encode_images( + self, + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + latent_list = [] + # Move to vae to device if on cpu + if self.vae.device == torch.device("cpu"): + self.vae.to(device) + self.vae.eval() + self.vae.requires_grad_(False) + # move to device and dtype + image_list = [image.to(device, dtype=dtype) for image in image_list] + + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + + # resize images if not divisible by 8 + for i in range(len(image_list)): + image = image_list[i] + if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0: + image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, + image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image) + + images = torch.stack(image_list) + if isinstance(self.vae, AutoencoderTiny): + latents = self.vae.encode(images, return_dict=False)[0] + else: + latents = self.vae.encode(images).latent_dist.sample() + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + + # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303 + # z = self.scale_factor * (z - self.shift_factor) + latents = self.vae.config['scaling_factor'] * (latents - shift) + latents = latents.to(device, dtype=dtype) + + return latents + + def encode_audio(self, audio_data_list): + # audio_date_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)} + raise NotImplementedError("Audio encoding not implemented for this model.") + + def 
decode_latents( + self, + latents: torch.Tensor, + device=None, + dtype=None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + # Move to vae to device if on cpu + if self.vae.device == torch.device("cpu"): + self.vae.to(self.device_torch) + latents = latents.to(self.device_torch, dtype=self.torch_dtype) + latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] + images = self.vae.decode(latents).sample + images = images.to(device, dtype=dtype) + + return images + + def encode_image_prompt_pairs( + self, + prompt_list: List[str], + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + # todo check image types and expand and rescale as needed + # device and dtype are for outputs + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + embedding_list = [] + latent_list = [] + # embed the prompts + for prompt in prompt_list: + embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype) + embedding_list.append(embedding) + + return embedding_list, latent_list + + def get_weight_by_name(self, name): + # weights begin with te{te_num}_ for text encoder + # weights begin with unet_ for unet_ + if name.startswith('te'): + key = name[4:] + # text encoder + te_num = int(name[2]) + if isinstance(self.text_encoder, list): + return self.text_encoder[te_num].state_dict()[key] + else: + return self.text_encoder.state_dict()[key] + elif name.startswith('unet'): + key = name[5:] + # unet + return self.unet.state_dict()[key] + + raise ValueError(f"Unknown weight name: {name}") + + def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False): + return inject_trigger_into_prompt( + prompt, + trigger=trigger, + to_replace_list=to_replace_list, + add_if_not_present=add_if_not_present, + ) + + def state_dict(self, vae=True, text_encoder=True, unet=True): + state_dict = OrderedDict() + if vae: + for 
k, v in self.vae.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + state_dict[new_key] = v + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + for k, v in encoder.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}" + state_dict[new_key] = v + else: + for k, v in self.text_encoder.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}" + state_dict[new_key] = v + if unet: + for k, v in self.unet.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + state_dict[new_key] = v + return state_dict + + def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \ + OrderedDict[ + str, Parameter]: + named_params: OrderedDict[str, Parameter] = OrderedDict() + if vae: + for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"): + named_params[name] = param + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0: + # dont add these params + continue + if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1: + # dont add these params + continue + + for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"): + named_params[name] = param + else: + for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): + named_params[name] = param + if unet: + if self.is_flux or self.is_lumina2: + for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"): + named_params[name] = param + else: + for name, param in self.unet.named_parameters(recurse=True, 
prefix=f"{SD_PREFIX_UNET}"): + named_params[name] = param + + if self.model_config.ignore_if_contains is not None: + # remove params that contain the ignore_if_contains from named params + for key in list(named_params.keys()): + if any([s in key for s in self.model_config.ignore_if_contains]): + del named_params[key] + if self.model_config.only_if_contains is not None: + # remove params that do not contain the only_if_contains from named params + for key in list(named_params.keys()): + if not any([s in key for s in self.model_config.only_if_contains]): + del named_params[key] + + if refiner: + for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"): + named_params[name] = param + + # convert to state dict keys, jsut replace . with _ on keys + if state_dict_keys: + new_named_params = OrderedDict() + for k, v in named_params.items(): + # replace only the first . with an _ + new_key = k.replace('.', '_', 1) + new_named_params[new_key] = v + named_params = new_named_params + + return named_params + + def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')): + + # load the full refiner since we only train unet + if self.model_config.refiner_name_or_path is None: + raise ValueError("Refiner must be specified to save it") + refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') + # load the refiner model + dtype = get_torch_dtype(self.dtype) + model_path = self.model_config._original_refiner_name_or_path + if not os.path.exists(model_path) or os.path.isdir(model_path): + # TODO only load unet?? 
+ refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( + model_path, + dtype=dtype, + device='cpu', + # variant="fp16", + use_safetensors=True, + ) + else: + refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( + model_path, + dtype=dtype, + device='cpu', + torch_dtype=self.torch_dtype, + original_config_file=refiner_config_path, + ) + # replace original unet + refiner.unet = self.refiner_unet + flush() + + diffusers_state_dict = OrderedDict() + for k, v in refiner.vae.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + diffusers_state_dict[new_key] = v + for k, v in refiner.text_encoder_2.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}" + diffusers_state_dict[new_key] = v + for k, v in refiner.unet.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + diffusers_state_dict[new_key] = v + + converted_state_dict = get_ldm_state_dict_from_diffusers( + diffusers_state_dict, + 'sdxl_refiner', + device='cpu', + dtype=save_dtype + ) + + # make sure parent folder exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) + + if self.config_file is not None: + output_path_no_ext = os.path.splitext(output_file)[0] + output_config_path = f"{output_path_no_ext}.yaml" + shutil.copyfile(self.config_file, output_config_path) + + def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): + version_string = '1' + if self.is_v2: + version_string = '2' + if self.is_xl: + version_string = 'sdxl' + if self.is_ssd: + # overwrite sdxl because both wil be true here + version_string = 'ssd' + if self.is_ssd and self.is_vega: + version_string = 'vega' + # if output file does not end in .safetensors, then it is a directory and we are + # saving in diffusers format + if not 
def prepare_optimizer_params(
        self,
        unet=False,
        text_encoder=False,
        text_encoder_lr=None,
        unet_lr=None,
        refiner_lr=None,
        refiner=False,
        default_lr=1e-6,
):
    """Build the optimizer parameter groups for the requested sub-models.

    One ``{"params": [...], "lr": ...}`` dict is appended per enabled
    sub-model, in the order unet, text encoder, refiner. Each ``*_lr``
    argument falls back to ``default_lr`` when it is ``None``.

    Only parameters with ``requires_grad`` set are collected. Except for
    the unet of pixart / auraflow / flux / v3 / lumina2 models (where every
    trainable unet parameter is taken), parameters are additionally limited
    to keys that appear as values in the ``ldm_diffusers_keymap`` of the
    keymap JSON and are not listed in ``DO_NOT_TRAIN_WEIGHTS``.

    Returns:
        list of parameter-group dicts.
    """
    # v2 wins over xl, which wins over the sd1 default
    if self.is_v2:
        version = 'sd2'
    elif self.is_xl:
        version = 'sdxl'
    else:
        version = 'sd1'

    keymap_path = os.path.join(KEYMAPS_ROOT, f"stable_diffusion_{version}.json")
    with open(keymap_path, 'r') as keymap_file:
        ldm_diffusers_keymap = json.load(keymap_file)['ldm_diffusers_keymap']

    def mapped_trainable(candidates, prefix=None):
        # walk the keymap values in file order and keep every mapped,
        # not-blacklisted parameter that still requires grad
        found = []
        for mapped_key in ldm_diffusers_keymap.values():
            lookup = mapped_key if prefix is None else f"{prefix}{mapped_key}"
            if lookup not in candidates or lookup in DO_NOT_TRAIN_WEIGHTS:
                continue
            tensor = candidates[lookup]
            if tensor.requires_grad:
                found.append(tensor)
        return found

    groups = []

    # the state dict keys are used to find the params

    if unet:
        candidates = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
        group_lr = default_lr if unet_lr is None else unet_lr
        skips_keymap = (
            self.is_pixart or self.is_auraflow or self.is_flux or self.is_v3 or self.is_lumina2
        )
        if skips_keymap:
            selected = [tensor for tensor in candidates.values() if tensor.requires_grad]
        else:
            selected = mapped_trainable(candidates)
        groups.append({"params": selected, "lr": group_lr})
        print_acc(f"Found {len(selected)} trainable parameter in unet")

    if text_encoder:
        candidates = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True)
        group_lr = default_lr if text_encoder_lr is None else text_encoder_lr
        selected = mapped_trainable(candidates)
        groups.append({"params": selected, "lr": group_lr})

        print_acc(f"Found {len(selected)} trainable parameter in text encoder")

    if refiner:
        candidates = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
                                           state_dict_keys=True)
        group_lr = default_lr if refiner_lr is None else refiner_lr
        # refiner parameters are looked up under a "refiner_" prefixed key
        selected = mapped_trainable(candidates, prefix="refiner_")
        groups.append({"params": selected, "lr": group_lr})

        print_acc(f"Found {len(selected)} trainable parameter in refiner")

    return groups
te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad + else: + te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad + + self.device_state['text_encoder'] = { + 'training': self.text_encoder.training, + 'device': self.text_encoder.device, + 'requires_grad': te_has_grad + } + if self.adapter is not None: + if isinstance(self.adapter, IPAdapter): + requires_grad = self.adapter.image_proj_model.training + adapter_device = self.unet.device + elif isinstance(self.adapter, T2IAdapter): + requires_grad = self.adapter.adapter.conv_in.weight.requires_grad + adapter_device = self.adapter.device + elif isinstance(self.adapter, ControlNetModel): + requires_grad = self.adapter.conv_in.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ClipVisionAdapter): + requires_grad = self.adapter.embedder.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, CustomAdapter): + requires_grad = self.adapter.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ReferenceAdapter): + # todo update this!! 
def set_device_state(self, state):
    """Apply a device-state dict to the vae, unet, text encoder(s), adapter and refiner.

    ``state`` is indexed by module name; each entry supplies ``'training'``
    and ``'device'`` and, for everything except the vae, ``'requires_grad'``.
    ``state['text_encoder']`` may be a single entry shared by every encoder
    or a list with one entry per encoder. Adapter and refiner entries are
    only read when the corresponding module is present on ``self``.
    """

    def switch_mode(module, training):
        if training:
            module.train()
        else:
            module.eval()

    def mode_then_move(module, entry):
        # used for text encoders: train/eval, then device, then grad flag
        switch_mode(module, entry['training'])
        module.to(entry['device'])
        module.requires_grad_(entry['requires_grad'])

    def move_then_mode(module, entry):
        # used for adapter / refiner: device, then grad flag, then train/eval
        module.to(entry['device'])
        module.requires_grad_(entry['requires_grad'])
        switch_mode(module, entry['training'])

    # vae: only mode and device are restored
    vae_entry = state['vae']
    switch_mode(self.vae, vae_entry['training'])
    self.vae.to(vae_entry['device'])

    # unet: the grad flag is normalised to a strict bool
    unet_entry = state['unet']
    switch_mode(self.unet, unet_entry['training'])
    self.unet.to(unet_entry['device'])
    self.unet.requires_grad_(True if unet_entry['requires_grad'] else False)

    if isinstance(self.text_encoder, list):
        for idx, encoder in enumerate(self.text_encoder):
            entry = state['text_encoder']
            if isinstance(entry, list):
                entry = entry[idx]
            mode_then_move(encoder, entry)
    else:
        mode_then_move(self.text_encoder, state['text_encoder'])

    if self.adapter is not None:
        move_then_mode(self.adapter, state['adapter'])

    if self.refiner_unet is not None:
        move_then_mode(self.refiner_unet, state['refiner_unet'])
    flush()
def text_encoder_to(self, *args, **kwargs):
    """Forward ``.to(*args, **kwargs)`` to the text encoder, or to each one when it is a list."""
    encoders = self.text_encoder if isinstance(self.text_encoder, list) else [self.text_encoder]
    for single_encoder in encoders:
        single_encoder.to(*args, **kwargs)
def get_base_model_version(self) -> str:
    """Return the base-model version string for this model.

    The architecture flags on ``self`` are checked from most to least
    specific and the first one that is set wins; when none is set the
    result is ``'sd_1.5'``.

    ``is_vega`` is tested before ``is_ssd``: ``save()`` in this class picks
    ``'vega'`` when both flags are set, so checking ``is_ssd`` first would
    report such a model as ``'ssd'``. Models with only one of the two flags
    set are unaffected by the ordering.
    """
    ordered_flags = (
        (self.is_pixart, 'pixart'),
        (self.is_v3, 'sd_3'),
        (self.is_auraflow, 'auraflow'),
        (self.is_flux, 'flux.1'),
        (self.is_lumina2, 'lumina2'),
        # vega must precede ssd (see docstring)
        (self.is_vega, 'vega'),
        (self.is_ssd, 'ssd'),
        (self.is_xl, 'sdxl_1.0'),
        (self.is_v2, 'sd_2.1'),
    )
    for flag, version in ordered_flags:
        if flag:
            return version
    return 'sd_1.5'
def convert_to_gram_matrix(inputs):
    """Compute a normalised Gram matrix for every item of a 4-D feature batch.

    Args:
        inputs: tensor indexed as (batch, filters, height, width). It is
            cast to float32 before any arithmetic.

    Returns:
        Tensor of shape (batch, filters, filters): the product of the
        flattened features with their own transpose, divided by
        ``height * width * filters``.
    """
    inputs = inputs.float()
    batch, filters, height, width = (
        inputs.shape[0], inputs.shape[1], inputs.shape[2], inputs.shape[3]
    )
    size = height * width * filters

    # reshape (not view): a float32 input passes through .float() unchanged,
    # so a non-viewable layout such as a spatial crop would make .view() raise
    feats = inputs.reshape(batch, filters, height * width)
    grams_raw = torch.matmul(feats, feats.transpose(1, 2))
    return grams_raw / size
# create a module to normalize input image so we can easily put it in a
# ``nn.Sequential``
class Normalization(nn.Module):
    """Per-channel ``(x - mean) / std`` normalisation for image batches.

    The three mean/std values are hard-coded (presumably the usual ImageNet
    statistics expected by torchvision models - confirm if it matters).

    Args:
        device: device the mean/std tensors are created on.
        dtype: dtype the normalised output is cast to.
    """

    def __init__(self, device, dtype=torch.float32):
        super(Normalization, self).__init__()
        self.dtype = dtype
        # Build the constants directly on the target device and shape them
        # [C x 1 x 1] so they broadcast against [B x C x H x W] input.
        # (Wrapping an existing tensor in torch.tensor() again only adds a
        # copy and a UserWarning.)
        # They stay plain float32 attributes, not buffers, so a later
        # ``module.to(dtype=...)`` does not change their precision.
        self.mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(-1, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225], device=device).view(-1, 1, 1)

    def forward(self, stacked_input):
        """Normalise ``stacked_input`` and cast the result to ``self.dtype``.

        A 4-channel input has its last channel dropped first (treated as
        alpha). The input is not rescaled to 0-1 beforehand.
        """
        # remove alpha channel if it exists
        if stacked_input.shape[1] == 4:
            stacked_input = stacked_input[:, :3, :, :]
        return ((stacked_input - self.mean) / self.std).to(self.dtype)
nn.Sequential(normalization) + + i = 0 # increment every time we see a conv + block = 1 + children = list(cnn.children()) + + output_layer = None + + for layer in children: + if isinstance(layer, nn.Conv2d): + i += 1 + name = f'conv{block}_{i}_raw' + elif isinstance(layer, nn.ReLU): + # name = 'relu_{}'.format(i) + name = f'conv{block}_{i}' # target this + # The in-place version doesn't play very nicely with the ``ContentLoss`` + # and ``StyleLoss`` we insert below. So we replace with out-of-place + # ones here. + layer = nn.ReLU(inplace=False) + elif isinstance(layer, nn.MaxPool2d): + name = 'pool_{}'.format(i) + block += 1 + i = 0 + elif isinstance(layer, nn.BatchNorm2d): + name = 'bn_{}'.format(i) + else: + raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__)) + + model.add_module(name, layer) + + if name in content_layers: + # add content loss: + content_loss = ContentLoss(single_target=single_target, device=device) + model.add_module("content_loss_{}_{}".format(block, i), content_loss) + content_losses.append(content_loss) + + if name in style_layers: + # add style loss: + style_loss = StyleLoss(single_target=single_target, device=device) + model.add_module("style_loss_{}_{}".format(block, i), style_loss) + style_losses.append(style_loss) + + if output_layer_name is not None and name == output_layer_name: + output_layer = OutputLayer(name) + model.add_module("output_layer_{}_{}".format(block, i), output_layer) + + # now we trim off the layers after the last content and style losses + for i in range(len(model) - 1, -1, -1): + if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss) or isinstance(model[i], OutputLayer): + break + + model = model[:(i + 1)] + model.to(dtype=dtype) + + return model, style_losses, content_losses, output_layer diff --git a/ai-toolkit/toolkit/timer.py b/ai-toolkit/toolkit/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..e849ba5faa1404f776063512fbd3107b9e692bd4 --- 
import time
from collections import OrderedDict, deque
import sys
import os

# check if is ui process will have IS_AI_TOOLKIT_UI in env
is_ui = os.environ.get("IS_AI_TOOLKIT_UI", "0") == "1"


class Timer:
    """Collects rolling timing samples for named sections of code.

    Each timer name keeps up to ``max_buffer`` most recent durations (in
    seconds). Sections are timed either with explicit ``start(name)`` /
    ``stop(name)`` calls or with ``with timer(name): ...``; the context
    manager form may be nested. ``print()`` reports the average per name
    and passes the averages to any registered hooks.
    """

    def __init__(self, name='Timer', max_buffer=10):
        self.name = name
        self.max_buffer = max_buffer
        self.timers = OrderedDict()
        self.active_timers = {}
        self.current_timer = None  # Name most recently passed to __call__
        self._context_stack = []  # Names of the currently entered ``with`` blocks
        self._after_print_hooks = []

    def start(self, timer_name):
        """Begin (or restart) timing ``timer_name``."""
        if timer_name not in self.timers:
            self.timers[timer_name] = deque(maxlen=self.max_buffer)
        # perf_counter is monotonic, so wall-clock adjustments cannot skew
        # or negate a measured duration
        self.active_timers[timer_name] = time.perf_counter()

    def cancel(self, timer_name):
        """Cancel an active timer without recording a sample."""
        self.active_timers.pop(timer_name, None)

    def stop(self, timer_name):
        """Stop ``timer_name`` and record its elapsed time.

        Raises:
            ValueError: if the timer is not currently running.
        """
        if timer_name not in self.active_timers:
            raise ValueError(f"Timer '{timer_name}' was not started!")

        elapsed_time = time.perf_counter() - self.active_timers.pop(timer_name)
        samples = self.timers[timer_name]
        samples.append(elapsed_time)

        # The deque's maxlen already caps the buffer; this only matters when
        # max_buffer was lowered after the deque was created.
        if len(samples) > self.max_buffer:
            samples.popleft()

    def add_after_print_hook(self, hook):
        """Register ``hook(timing_dict)`` to be called at the end of ``print()``."""
        self._after_print_hooks.append(hook)

    def print(self):
        """Print average timings (longest total first) and run the hooks.

        Timers without any recorded sample (started but never stopped, or
        only ever cancelled) are skipped instead of dividing by zero.
        Console output is suppressed when running under the UI.
        """
        if not is_ui:
            print(f"\nTimer '{self.name}':")
        timing_dict = {}
        # sort by longest at top
        for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
            if not timings:
                continue
            avg_time = sum(timings) / len(timings)

            if not is_ui:
                print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
            timing_dict[timer_name] = avg_time

        for hook in self._after_print_hooks:
            hook(timing_dict)
        if not is_ui:
            print('')

    def reset(self):
        """Drop all recorded samples and all running timers."""
        self.timers.clear()
        self.active_timers.clear()

    def __call__(self, timer_name):
        """Enable the use of the Timer class as a context manager."""
        self.current_timer = timer_name
        self.start(timer_name)
        return self

    def __enter__(self):
        # Remember which timer this ``with`` block belongs to, so nested
        # blocks each stop their own timer on exit.
        self._context_stack.append(self.current_timer)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        timer_name = self._context_stack.pop() if self._context_stack else self.current_timer
        if self._context_stack:
            # back inside an outer block: it becomes the current timer again
            self.current_timer = self._context_stack[-1]
        if exc_type is None:
            # No exceptions, stop the timer normally
            self.stop(timer_name)
        else:
            # There was an exception, cancel the timer
            self.cancel(timer_name)
+ +default_weighing_scheme = [ + 0.905706524848938, + 0.9097874164581299, + 0.9251009821891785, + 0.9399133920669556, + 0.9497355818748474, + 0.962195873260498, + 0.9691638946533203, + 0.9927358627319336, + 1.01659095287323, + 1.0392684936523438, + 1.0429067611694336, + 1.0677919387817383, + 1.092896580696106, + 1.1149251461029053, + 1.1015851497650146, + 1.1209120750427246, + 1.1399472951889038, + 1.1559219360351562, + 1.143755555152893, + 1.160578727722168, + 1.1761468648910522, + 1.1895899772644043, + 1.1867371797561646, + 1.1993824243545532, + 1.2113513946533203, + 1.2201216220855713, + 1.221794605255127, + 1.2329251766204834, + 1.2426395416259766, + 1.251118540763855, + 1.2562459707260132, + 1.2668883800506592, + 1.2760146856307983, + 1.2822540998458862, + 1.2892439365386963, + 1.2972776889801025, + 1.3042182922363281, + 1.3104143142700195, + 1.3190876245498657, + 1.3253265619277954, + 1.3295484781265259, + 1.3324649333953857, + 1.3436106443405151, + 1.3482736349105835, + 1.351183533668518, + 1.3556060791015625, + 1.359933614730835, + 1.3625402450561523, + 1.3641390800476074, + 1.3723753690719604, + 1.3748365640640259, + 1.3776681423187256, + 1.3802807331085205, + 1.3848408460617065, + 1.3877429962158203, + 1.390969157218933, + 1.3927756547927856, + 1.4031920433044434, + 1.4064390659332275, + 1.4097232818603516, + 1.4132285118103027, + 1.4208053350448608, + 1.4254504442214966, + 1.4294012784957886, + 1.4323071241378784, + 1.44380521774292, + 1.4465110301971436, + 1.4490883350372314, + 1.451446533203125, + 1.458960771560669, + 1.460111379623413, + 1.4603426456451416, + 1.4595434665679932, + 1.4694839715957642, + 1.470526099205017, + 1.4705318212509155, + 1.4697961807250977, + 1.477265477180481, + 1.4771337509155273, + 1.4768965244293213, + 1.4760032892227173, + 1.4871419668197632, + 1.4877341985702515, + 1.488408088684082, + 1.489357829093933, + 1.4884581565856934, + 1.4882304668426514, + 1.4878745079040527, + 1.4949846267700195, + 1.4957915544509888, + 
1.4951081275939941, + 1.4950779676437378, + 1.4995663166046143, + 1.4994632005691528, + 1.4986882209777832, + 1.4976568222045898, + 1.5038518905639648, + 1.503743052482605, + 1.502982258796692, + 1.502566933631897, + 1.5068002939224243, + 1.506575584411621, + 1.5067178010940552, + 1.5050104856491089, + 1.508437991142273, + 1.5073161125183105, + 1.506223440170288, + 1.504751205444336, + 1.512485384941101, + 1.5121595859527588, + 1.5116060972213745, + 1.5102864503860474, + 1.5156408548355103, + 1.5157924890518188, + 1.5142825841903687, + 1.5141233205795288, + 1.5191627740859985, + 1.518998622894287, + 1.5177574157714844, + 1.516713261604309, + 1.5186537504196167, + 1.5179635286331177, + 1.5159885883331299, + 1.5150446891784668, + 1.5226575136184692, + 1.5215232372283936, + 1.5201103687286377, + 1.5225480794906616, + 1.5215368270874023, + 1.520209789276123, + 1.5184354782104492, + 1.5220975875854492, + 1.5209708213806152, + 1.519667625427246, + 1.5177332162857056, + 1.5203157663345337, + 1.519271731376648, + 1.5175830125808716, + 1.5161852836608887, + 1.517741322517395, + 1.5163099765777588, + 1.5150357484817505, + 1.513055682182312, + 1.5154107809066772, + 1.513993501663208, + 1.5126358270645142, + 1.510875940322876, + 1.5125701427459717, + 1.510590672492981, + 1.5086392164230347, + 1.5068501234054565, + 1.5094420909881592, + 1.5080277919769287, + 1.5060293674468994, + 1.5040230751037598, + 1.504892110824585, + 1.5031654834747314, + 1.5013214349746704, + 1.499170184135437, + 1.500420331954956, + 1.4979915618896484, + 1.4959659576416016, + 1.494127631187439, + 1.4961415529251099, + 1.494597315788269, + 1.4926577806472778, + 1.4904958009719849, + 1.4913445711135864, + 1.489889144897461, + 1.4880430698394775, + 1.4873780012130737, + 1.485144019126892, + 1.4830609560012817, + 1.48123037815094, + 1.483134150505066, + 1.4808504581451416, + 1.4793643951416016, + 1.4773707389831543, + 1.4774049520492554, + 1.4752962589263916, + 1.4733281135559082, + 1.4712235927581787, + 
1.470260500907898, + 1.4684780836105347, + 1.4657323360443115, + 1.4639259576797485, + 1.4644291400909424, + 1.4621423482894897, + 1.4604870080947876, + 1.4585931301116943, + 1.458650827407837, + 1.456871747970581, + 1.4548134803771973, + 1.4528228044509888, + 1.4519706964492798, + 1.4501826763153076, + 1.448083519935608, + 1.446316123008728, + 1.4454604387283325, + 1.4432463645935059, + 1.4412567615509033, + 1.4390138387680054, + 1.4388699531555176, + 1.437005877494812, + 1.4348008632659912, + 1.4327361583709717, + 1.4320861101150513, + 1.4304683208465576, + 1.4287059307098389, + 1.4264065027236938, + 1.424929141998291, + 1.4226492643356323, + 1.4205398559570312, + 1.4207192659378052, + 1.4187930822372437, + 1.4170515537261963, + 1.415097713470459, + 1.4142191410064697, + 1.412431001663208, + 1.4104442596435547, + 1.4082318544387817, + 1.4076896905899048, + 1.4057838916778564, + 1.4034852981567383, + 1.4016139507293701, + 1.4003900289535522, + 1.3983466625213623, + 1.3962912559509277, + 1.3940908908843994, + 1.3933204412460327, + 1.3910188674926758, + 1.3888089656829834, + 1.3865182399749756, + 1.3853325843811035, + 1.3832392692565918, + 1.3812776803970337, + 1.3789170980453491, + 1.3797601461410522, + 1.3777130842208862, + 1.3756908178329468, + 1.3735847473144531, + 1.3711339235305786, + 1.3690725564956665, + 1.366849660873413, + 1.364676833152771, + 1.364690899848938, + 1.362541675567627, + 1.3608990907669067, + 1.358725666999817, + 1.3566731214523315, + 1.3549315929412842, + 1.3522361516952515, + 1.3527694940567017, + 1.3506978750228882, + 1.3486748933792114, + 1.3464853763580322, + 1.344880223274231, + 1.3429863452911377, + 1.3410741090774536, + 1.3392194509506226, + 1.3374866247177124, + 1.3356633186340332, + 1.333699345588684, + 1.3316830396652222, + 1.3293700218200684, + 1.327602744102478, + 1.325606107711792, + 1.323636770248413, + 1.322487473487854, + 1.3201935291290283, + 1.3185865879058838, + 1.3163570165634155, + 1.315348505973816, + 
1.3135069608688354, + 1.3115581274032593, + 1.309351921081543, + 1.3080940246582031, + 1.306084394454956, + 1.3041918277740479, + 1.3022172451019287, + 1.30052649974823, + 1.2984906435012817, + 1.296433925628662, + 1.294618844985962, + 1.2924813032150269, + 1.290629267692566, + 1.288609504699707, + 1.286437749862671, + 1.2852808237075806, + 1.2831010818481445, + 1.281022071838379, + 1.2789161205291748, + 1.2785669565200806, + 1.2766060829162598, + 1.274585485458374, + 1.2728400230407715, + 1.2709832191467285, + 1.2691309452056885, + 1.2671318054199219, + 1.265442132949829, + 1.2635501623153687, + 1.2614946365356445, + 1.2593908309936523, + 1.257880449295044, + 1.2560313940048218, + 1.254082441329956, + 1.2522804737091064, + 1.2505096197128296, + 1.2482692003250122, + 1.2462091445922852, + 1.2445822954177856, + 1.2432236671447754, + 1.2414650917053223, + 1.2396503686904907, + 1.2376699447631836, + 1.2357380390167236, + 1.2339240312576294, + 1.2320566177368164, + 1.2299892902374268, + 1.2286840677261353, + 1.226925015449524, + 1.2250070571899414, + 1.223126769065857, + 1.2215166091918945, + 1.2196996212005615, + 1.2178195714950562, + 1.2158279418945312, + 1.2140803337097168, + 1.2121261358261108, + 1.210100769996643, + 1.2083507776260376, + 1.2063525915145874, + 1.2046012878417969, + 1.2027149200439453, + 1.201154112815857, + 1.1992254257202148, + 1.1971834897994995, + 1.1951549053192139, + 1.1935709714889526, + 1.191764235496521, + 1.1898930072784424, + 1.187896728515625, + 1.186535120010376, + 1.184600591659546, + 1.1826894283294678, + 1.1809728145599365, + 1.1789331436157227, + 1.1774684190750122, + 1.1756458282470703, + 1.1736308336257935, + 1.1719911098480225, + 1.1701127290725708, + 1.1681468486785889, + 1.165921926498413, + 1.16463041305542, + 1.1627451181411743, + 1.1608567237854004, + 1.1590938568115234, + 1.1575335264205933, + 1.1555901765823364, + 1.1538552045822144, + 1.1518657207489014, + 1.14994215965271, + 1.1481153964996338, + 1.14644455909729, + 
1.1444120407104492, + 1.1428122520446777, + 1.1410313844680786, + 1.1391836404800415, + 1.137330412864685, + 1.135434627532959, + 1.1336791515350342, + 1.131978154182434, + 1.1300874948501587, + 1.128359317779541, + 1.1264809370040894, + 1.1248611211776733, + 1.122762680053711, + 1.1209162473678589, + 1.1190710067749023, + 1.1172044277191162, + 1.1158984899520874, + 1.1140459775924683, + 1.1124012470245361, + 1.110682487487793, + 1.1087219715118408, + 1.106826901435852, + 1.1050584316253662, + 1.1034021377563477, + 1.1011031866073608, + 1.0996853113174438, + 1.0978131294250488, + 1.0963127613067627, + 1.0944904088974, + 1.0927494764328003, + 1.0910944938659668, + 1.0892736911773682, + 1.0878331661224365, + 1.0860958099365234, + 1.0842169523239136, + 1.0826303958892822, + 1.0806686878204346, + 1.078961730003357, + 1.0773676633834839, + 1.0755786895751953, + 1.073934555053711, + 1.0721861124038696, + 1.0704376697540283, + 1.0689181089401245, + 1.067183256149292, + 1.0654473304748535, + 1.0637754201889038, + 1.0620981454849243, + 1.0604465007781982, + 1.0587077140808105, + 1.0570865869522095, + 1.0553107261657715, + 1.053688883781433, + 1.0520380735397339, + 1.0502020120620728, + 1.048741340637207, + 1.046962022781372, + 1.0453627109527588, + 1.0439050197601318, + 1.041886806488037, + 1.0405514240264893, + 1.0387938022613525, + 1.0370451211929321, + 1.035706877708435, + 1.03403902053833, + 1.0325669050216675, + 1.0308712720870972, + 1.0291008949279785, + 1.0275760889053345, + 1.0258709192276, + 1.024153470993042, + 1.022807240486145, + 1.0211118459701538, + 1.019489049911499, + 1.0178107023239136, + 1.0159832239151, + 1.0143824815750122, + 1.0128840208053589, + 1.0111985206604004, + 1.0098742246627808, + 1.0081874132156372, + 1.0064918994903564, + 1.0050266981124878, + 1.0036821365356445, + 1.0018991231918335, + 1.0004172325134277, + 0.9988566637039185, + 0.9969817399978638, + 0.9954714179039001, + 0.9939242005348206, + 0.9923979640007019, + 0.9910774230957031, + 
0.9894015789031982, + 0.9880895614624023, + 0.9861252903938293, + 0.9846389889717102, + 0.9831112027168274, + 0.9815076589584351, + 0.9799305200576782, + 0.9784950017929077, + 0.976923942565918, + 0.975475549697876, + 0.9737277626991272, + 0.9722781181335449, + 0.9707712531089783, + 0.9693742394447327, + 0.9677569270133972, + 0.9663806557655334, + 0.9648120999336243, + 0.963326096534729, + 0.9619874358177185, + 0.9605197906494141, + 0.9590029120445251, + 0.9575618505477905, + 0.9558634757995605, + 0.9542866945266724, + 0.9530059099197388, + 0.9513764977455139, + 0.9499674439430237, + 0.948621392250061, + 0.947046160697937, + 0.945502519607544, + 0.9441988468170166, + 0.9427464604377747, + 0.9413387179374695, + 0.9397821426391602, + 0.9385508894920349, + 0.9372508525848389, + 0.9356773495674133, + 0.9340954422950745, + 0.9325379133224487, + 0.9311357140541077, + 0.9296550154685974, + 0.9283716082572937, + 0.9268398880958557, + 0.9254037141799927, + 0.9239259362220764, + 0.9225856065750122, + 0.921108603477478, + 0.9197893142700195, + 0.9185012578964233, + 0.9169778823852539, + 0.9154301881790161, + 0.9140625, + 0.9127756357192993, + 0.9113842844963074, + 0.9101965427398682, + 0.9088224172592163, + 0.9074375629425049, + 0.9061430096626282, + 0.9046499133110046, + 0.9033547043800354, + 0.9018712639808655, + 0.9006990790367126, + 0.8993589878082275, + 0.8980291485786438, + 0.8965833187103271, + 0.8953617811203003, + 0.8940249681472778, + 0.8928234577178955, + 0.8914735913276672, + 0.8900470733642578, + 0.8885773420333862, + 0.887448251247406, + 0.8860753178596497, + 0.8848751783370972, + 0.8835704326629639, + 0.8822427988052368, + 0.8808343410491943, + 0.8794860243797302, + 0.8782272338867188, + 0.876940131187439, + 0.8755697011947632, + 0.8743593096733093, + 0.8731096982955933, + 0.8717764019966125, + 0.870373547077179, + 0.869137704372406, + 0.8679963946342468, + 0.8665465116500854, + 0.8653771281242371, + 0.8643192052841187, + 0.8630129098892212, + 
0.8618021011352539, + 0.8606610894203186, + 0.8596193194389343, + 0.8584977984428406, + 0.8571111559867859, + 0.8558118343353271, + 0.854767382144928, + 0.8535858392715454, + 0.8525562882423401, + 0.851208508014679, + 0.8500548601150513, + 0.8489854335784912, + 0.8476380109786987, + 0.8465084433555603, + 0.8454263806343079, + 0.8440982699394226, + 0.8429536819458008, + 0.8419493436813354, + 0.8406177759170532, + 0.8395005464553833, + 0.83843994140625, + 0.8372390866279602, + 0.836262583732605, + 0.8351759910583496, + 0.8340833187103271, + 0.8330100178718567, + 0.8318305611610413, + 0.8307360410690308, + 0.8296796083450317, + 0.8287205696105957, + 0.8275678753852844, + 0.8264811038970947, + 0.8253570795059204, + 0.8243551254272461, + 0.8232539296150208, + 0.822137176990509, + 0.8212800025939941, + 0.8199703097343445, + 0.8190608024597168, + 0.8179953098297119, + 0.8167867064476013, + 0.8158150315284729, + 0.8149182200431824, + 0.8140754699707031, + 0.8131433129310608, + 0.8118599057197571, + 0.8109708428382874, + 0.8099024891853333, + 0.8090004324913025, + 0.8079776763916016, + 0.807029664516449, + 0.8058684468269348, + 0.8049055337905884, + 0.8039948344230652, + 0.803061306476593, + 0.8021382689476013, + 0.8012913465499878, + 0.8002091646194458, + 0.7992268204689026, + 0.7981467247009277, + 0.7973214983940125, + 0.7964017987251282, + 0.7954541444778442, + 0.7945792078971863, + 0.7938122153282166, + 0.7926003932952881, + 0.7917800545692444, + 0.7908596396446228, + 0.7899304628372192, + 0.7890149354934692, + 0.7882192730903625, + 0.7870058417320251, + 0.7863731980323792, + 0.7852027416229248, + 0.7844488024711609, + 0.783501386642456, + 0.7827003598213196, + 0.7819803357124329, + 0.7808201909065247, + 0.7800688147544861, + 0.7791293263435364, + 0.7784658670425415, + 0.7775732278823853, + 0.7768633961677551, + 0.7760342359542847, + 0.775243878364563, + 0.7743030786514282, + 0.7735926508903503, + 0.7724748849868774, + 0.7718163728713989, + 0.77097487449646, + 
0.7702510356903076, + 0.7693900465965271, + 0.7687169313430786, + 0.7678922414779663, + 0.7672128081321716, + 0.7663589715957642, + 0.7657037377357483, + 0.7647771239280701, + 0.7640203237533569, + 0.7633466720581055, + 0.7625623941421509, + 0.7617509961128235, + 0.7610896229743958, + 0.760379433631897, + 0.7596492767333984, + 0.7588953971862793, + 0.7581916451454163, + 0.7573999166488647, + 0.7568274736404419, + 0.756077229976654, + 0.7554765939712524, + 0.7546539306640625, + 0.7539674043655396, + 0.753139853477478, + 0.7525543570518494, + 0.7519160509109497, + 0.7513154149055481, + 0.7505142688751221, + 0.7497125864028931, + 0.74923175573349, + 0.7484207153320312, + 0.7479155659675598, + 0.7473617792129517, + 0.7468436360359192, + 0.7462318539619446, + 0.7456430792808533, + 0.7447810769081116, + 0.7442206144332886, + 0.7435954809188843, + 0.7431489825248718, + 0.7422271370887756, + 0.7418114542961121, + 0.7412892580032349, + 0.740713357925415, + 0.7401546239852905, + 0.7396021485328674, + 0.7390599846839905, + 0.73844313621521, + 0.7377904653549194, + 0.737305223941803, + 0.7368288636207581, + 0.7363747358322144, + 0.7358483076095581, + 0.7354381680488586, + 0.7348212003707886, + 0.7343763709068298, + 0.7336553335189819, + 0.7332231402397156, + 0.73262619972229, + 0.7321929931640625, + 0.7315752506256104, + 0.7312256693840027, + 0.7306149005889893, + 0.7302426695823669, + 0.7299467325210571, + 0.7294563055038452, + 0.728706419467926, + 0.7283353209495544, + 0.7279900312423706, + 0.7276231646537781, + 0.7273217439651489, + 0.7269001007080078, + 0.7265130877494812, + 0.7261000871658325, + 0.7257733345031738, + 0.725188672542572, + 0.724976658821106, + 0.7242119908332825, + 0.7238465547561646, + 0.7236427664756775, + 0.7232236266136169, + 0.7227877974510193, + 0.7226144075393677, + 0.7221262454986572, + 0.7218216061592102, + 0.7215451002120972, + 0.7213869094848633, + 0.7209206819534302, + 0.7207257747650146, + 0.7203994989395142, + 0.7200448513031006, + 
0.7197679281234741, + 0.7195186018943787, + 0.7190226912498474, + 0.7188836932182312, + 0.7186117768287659, + 0.7185105681419373, + 0.718199610710144, + 0.7180152535438538, + 0.7175536155700684, + 0.7173341512680054, + 0.7171205878257751, + 0.7168837189674377, + 0.7163654565811157, + 0.7162774801254272, + 0.7161651253700256, + 0.7160663604736328, + 0.7159175872802734, + 0.7157440185546875, + 0.7154026031494141, + 0.7153436541557312, + 0.715220034122467, + 0.7150475978851318, + 0.7150062322616577, + 0.7149052619934082, + 0.7147804498672485, + 0.7147180438041687, + 0.7146442532539368, + 0.7143230438232422, + 0.7142894268035889, + 0.7143252491950989, + 0.7141361236572266, + 0.7140751481056213, + 0.7138771414756775, + 0.7138750553131104, + 0.7138450145721436, + 0.7138748168945312, + 0.7137607336044312, + 0.7137340903282166, + 0.7137055993080139, + 0.7136792540550232, + 0.7135677337646484, + 0.7134214639663696, + 0.7135778069496155, + 0.7136402130126953, + 0.713737964630127, + 0.7136131525039673, + 0.7135958075523376, + 0.713367760181427, + 0.7136869430541992, + 0.7137601971626282, + 0.7137682437896729, + 0.7137079834938049, + 0.7138195633888245, + 0.713512122631073, + 0.7136629819869995, + 0.7137271761894226, + 0.7138593792915344, + 0.714098334312439, + 0.714293360710144, + 0.7142848372459412, + 0.7144456505775452, + 0.714730978012085, + 0.7147268652915955, + 0.7149925231933594, + 0.7151434421539307, + 0.7151704430580139, + 0.7152837514877319, + 0.7152866125106812, + 0.7155521512031555, + 0.7158286571502686, + 0.7161067724227905, + 0.7161760926246643, + 0.716302752494812, + 0.7165276408195496, + 0.716739296913147, + 0.7168811559677124, + 0.7171459794044495, + 0.7174181342124939, + 0.7176154255867004, + 0.7177509069442749, + 0.7180988192558289, + 0.718157172203064, + 0.7184935212135315, + 0.7185874581336975, + 0.7188709378242493, + 0.7191057801246643, + 0.7192631959915161, + 0.7195265293121338, + 0.7197085618972778, + 0.720137357711792, + 0.7203015089035034, + 
0.7204228639602661, + 0.720592737197876, + 0.7209603786468506, + 0.7212156057357788, + 0.7214911580085754, + 0.72171950340271, + 0.7221818566322327, + 0.7225285768508911, + 0.7227242588996887, + 0.7229769825935364, + 0.7232232689857483, + 0.7235181927680969, + 0.7238690853118896, + 0.7240993976593018, + 0.7244613170623779, + 0.7245984673500061, + 0.7250016331672668, + 0.7251067161560059, + 0.7255734205245972, + 0.7258168458938599, + 0.7260231375694275, + 0.7262008786201477, + 0.7266826629638672, + 0.7266356945037842, + 0.7272213697433472, + 0.7274652719497681, + 0.7279736399650574, + 0.7281084656715393, + 0.7283049821853638, + 0.7285397052764893, + 0.7289696931838989, + 0.7293713688850403, + 0.72969651222229, + 0.7298030853271484, + 0.7300551533699036, + 0.7303762435913086, + 0.7306288480758667, + 0.7307636141777039, + 0.7312272787094116, + 0.7314295172691345, + 0.7316324710845947, + 0.7314282655715942, + 0.7316331267356873, + 0.7319273352622986, + 0.732114315032959, + 0.732216477394104, + 0.7322894930839539, + 0.7325413227081299, + 0.7325369715690613, + 0.7325401902198792, + 0.7323561906814575, + 0.7322503924369812, + 0.7322250008583069, + 0.7320146560668945, + 0.7321474552154541, + 0.732187032699585, + 0.7320796251296997, + 0.7315171360969543, + 0.7313243746757507, + 0.7310792803764343, + 0.7308080196380615, + 0.7306304574012756, + 0.7296295166015625, + 0.7289745807647705, + 0.7288649082183838, + 0.7281786799430847, + 0.7277448773384094, + 0.7272322177886963, + 0.7265949845314026, + 0.725848376750946, + 0.7252488136291504, + 0.7245422601699829, + 0.723800003528595, + 0.7228619456291199, + 0.7218728065490723, + 0.720892071723938, + 0.7198144793510437, + 0.7186352610588074, + 0.717454731464386, + 0.7162171602249146, + 0.7149246335029602, + 0.7136655449867249, + 0.7121627926826477, + 0.7108365297317505, + 0.7092090249061584, + 0.7076245546340942, + 0.7060236930847168, + 0.704273521900177, + 0.7023670673370361, + 0.7007937431335449, + 0.698858916759491, + 
0.696875810623169, + 0.69483882188797, + 0.6926799416542053, + 0.6907360553741455, + 0.6884570121765137, + 0.68642657995224, + 0.6837813258171082, + 0.6816286444664001, + 0.6790634989738464, + 0.6767467260360718, + 0.6742716431617737, + 0.6715688705444336, + 0.6689924001693726, + 0.6662940382957458, + 0.6634176969528198, + 0.6603904962539673, + 0.6574862599372864, + 0.6544369459152222, + 0.651530385017395, + 0.6485568284988403, + 0.6453983187675476, + 0.6423068046569824, + 0.6392328143119812, + 0.6360040307044983, + 0.6325976252555847, + 0.6291093826293945, + 0.6256147623062134, + 0.6223041415214539, + 0.6188730001449585, + 0.615329921245575, + 0.6118036508560181, + 0.6081951260566711, + 0.604539155960083, + 0.6009404063224792, + 0.5972440242767334, + 0.5937116146087646, + 0.5897051692008972, + 0.5859677195549011, + 0.5821571946144104, + 0.578381359577179, + 0.5747998952865601, + 0.5709896683692932, + 0.5671953558921814, + 0.5633583068847656, + 0.5594204068183899, + 0.5555382370948792, + 0.5519280433654785, + 0.5482934713363647, + 0.544551432132721, + 0.5410515666007996, + 0.5374910831451416, + 0.5340041518211365, + 0.5304144024848938, + 0.5269584655761719, + 0.5235306620597839, + 0.520039975643158, + 0.516674280166626, + 0.513296902179718, + 0.5098193883895874, + 0.5064578652381897, + 0.5030517578125, + 0.4997297525405884, + 0.4967145025730133, + 0.49335765838623047, + 0.4902186989784241, + 0.486634224653244, + 0.48311659693717957, + 0.4792158007621765, + 0.4755136966705322, + 0.4720709025859833, + 0.4689248502254486, + 0.4660993814468384, + 0.46355342864990234, + 0.46058982610702515, + 0.45763304829597473, + 0.45535609126091003, + 0.45405313372612, + 0.45241352915763855, + 0.45207348465919495, + 0.45095735788345337, + 0.45052871108055115, + 0.449806272983551, + 0.4484655559062958, + 0.44648951292037964, + 0.44580715894699097, + 0.4447800815105438, + 0.4453802704811096, + 0.4472601115703583 +] \ No newline at end of file diff --git 
a/ai-toolkit/toolkit/train_pipelines.py b/ai-toolkit/toolkit/train_pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cc623cd55e802bcad1de41cd90be6a57d2743a --- /dev/null +++ b/ai-toolkit/toolkit/train_pipelines.py @@ -0,0 +1,316 @@ +from typing import Optional, Tuple, Callable, Dict, Any, Union, List + +import torch +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg + +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.pipelines import CustomStableDiffusionXLPipeline + + +class TransferStableDiffusionXLPipeline(CustomStableDiffusionXLPipeline): + def transfer_diffuse( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + 
target_unet: Optional[torch.nn.Module] = None, + pre_condition_callback = None, + each_step_callback = None, + network: Optional[LoRASpecialNetwork] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. 
of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. 
Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + conditioned_noise_pred, conditioned_latent_model_input = pre_condition_callback( + noise_pred.clone().detach(), + latent_model_input.clone().detach(), + ) + + # start grad + with torch.enable_grad(): + with network: + assert network.is_active + noise_train_pred = target_unet( + conditioned_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + each_step_callback(conditioned_noise_pred, noise_train_pred) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and 
guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + diff --git a/ai-toolkit/toolkit/train_tools.py b/ai-toolkit/toolkit/train_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..78e2183c3725b43fab3a437442de855c044771af --- /dev/null +++ b/ai-toolkit/toolkit/train_tools.py @@ -0,0 +1,765 @@ +import argparse +import hashlib +import json +import os +import time +from typing import TYPE_CHECKING, Union, List +import sys + + +from diffusers import ( + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler +) +import torch +import re +from transformers import T5Tokenizer, T5EncoderModel, UMT5EncoderModel + +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL +TEXT_ENCODER_2_PROJECTION_DIM = 1280 +UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816 + + +def get_torch_dtype(dtype_str): + # if it is a torch dtype, return it + if isinstance(dtype_str, torch.dtype): + return dtype_str + if dtype_str == "float" or dtype_str == "fp32" or dtype_str == "single" or dtype_str == "float32": + return torch.float + if dtype_str == "fp16" or dtype_str == "half" or dtype_str == "float16": + return 
torch.float16 + if dtype_str == "bf16" or dtype_str == "bfloat16": + return torch.bfloat16 + if dtype_str == "8bit" or dtype_str == "e4m3fn" or dtype_str == "float8": + return torch.float8_e4m3fn + return dtype_str + + +def replace_filewords_prompt(prompt, args: argparse.Namespace): + # if name_replace attr in args (may not be) + if hasattr(args, "name_replace") and args.name_replace is not None: + # replace [name] to args.name_replace + prompt = prompt.replace("[name]", args.name_replace) + if hasattr(args, "prepend") and args.prepend is not None: + # prepend to every item in prompt file + prompt = args.prepend + ' ' + prompt + if hasattr(args, "append") and args.append is not None: + # append to every item in prompt file + prompt = prompt + ' ' + args.append + return prompt + + +def replace_filewords_in_dataset_group(dataset_group, args: argparse.Namespace): + # if name_replace attr in args (may not be) + if hasattr(args, "name_replace") and args.name_replace is not None: + if not len(dataset_group.image_data) > 0: + # throw error + raise ValueError("dataset_group.image_data is empty") + for key in dataset_group.image_data: + dataset_group.image_data[key].caption = dataset_group.image_data[key].caption.replace( + "[name]", args.name_replace) + + return dataset_group + + +def get_seeds_from_latents(latents): + # latents shape = (batch_size, 4, height, width) + # for speed we only use 8x8 slice of the first channel + seeds = [] + + # split batch up + for i in range(latents.shape[0]): + # use only first channel, multiply by 255 and convert to int + tensor = latents[i, 0, :, :] * 255.0 # shape = (height, width) + # slice 8x8 + tensor = tensor[:8, :8] + # clip to 0-255 + tensor = torch.clamp(tensor, 0, 255) + # convert to 8bit int + tensor = tensor.to(torch.uint8) + # convert to bytes + tensor_bytes = tensor.cpu().numpy().tobytes() + # hash + hash_object = hashlib.sha256(tensor_bytes) + # get hex + hex_dig = hash_object.hexdigest() + # convert to int + seed = 
int(hex_dig, 16) % (2 ** 32) + # append + seeds.append(seed) + return seeds + + +def get_noise_from_latents(latents): + seed_list = get_seeds_from_latents(latents) + noise = [] + for seed in seed_list: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + noise.append(torch.randn_like(latents[0])) + return torch.stack(noise) + + +# mix 0 is completely noise mean, mix 1 is completely target mean + +def match_noise_to_target_mean_offset(noise, target, mix=0.5, dim=None): + dim = dim or (1, 2, 3) + # reduce mean of noise on dim 2, 3, keeping 0 and 1 intact + noise_mean = noise.mean(dim=dim, keepdim=True) + target_mean = target.mean(dim=dim, keepdim=True) + + new_noise_mean = mix * target_mean + (1 - mix) * noise_mean + + noise = noise - noise_mean + new_noise_mean + return noise + + +# https://www.crosslabs.org//blog/diffusion-with-offset-noise +def apply_noise_offset(noise, noise_offset): + if noise_offset is None or (noise_offset < 0.000001 and noise_offset > -0.000001): + return noise + if len(noise.shape) > 4: + raise ValueError("Applying noise offset not supported for video models at this time.") + noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device) + return noise + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import PromptEmbeds + + +def concat_prompt_embeddings( + unconditional: 'PromptEmbeds', + conditional: 'PromptEmbeds', + n_imgs: int=0, +): + from toolkit.stable_diffusion_model import PromptEmbeds + text_embeds = torch.cat( + [unconditional.text_embeds, conditional.text_embeds] + ).repeat_interleave(n_imgs, dim=0) + pooled_embeds = None + if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None: + pooled_embeds = torch.cat( + [unconditional.pooled_embeds, conditional.pooled_embeds] + ).repeat_interleave(n_imgs, dim=0) + return PromptEmbeds([text_embeds, pooled_embeds]) + + +def addnet_hash_safetensors(b): + """New model hash used by 
sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + +if TYPE_CHECKING: + from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection + + +def text_tokenize( + tokenizer: 'CLIPTokenizer', + prompts: list[str], + truncate: bool = True, + max_length: int = None, + max_length_multiplier: int = 4, +): + # allow fo up to 4x the max length for long prompts + if max_length is None: + if truncate: + max_length = tokenizer.model_max_length + else: + # allow up to 4x the max length for long prompts + max_length = tokenizer.model_max_length * max_length_multiplier + + input_ids = tokenizer( + prompts, + padding='max_length', + max_length=max_length, + truncation=True, + return_tensors="pt", + ).input_ids + + if truncate or max_length == tokenizer.model_max_length: + return input_ids + else: + # remove additional padding + num_chunks = input_ids.shape[1] // tokenizer.model_max_length + chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1) + + # New list to store non-redundant chunks + non_redundant_chunks = [] + + for chunk in chunks: + if not chunk.eq(chunk[0, 0]).all(): # Check if all elements in the chunk are the same as the first element + non_redundant_chunks.append(chunk) + + input_ids = torch.cat(non_redundant_chunks, dim=1) + return input_ids + + +# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348 +def 
text_encode_xl( + text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'], + tokens: torch.FloatTensor, + num_images_per_prompt: int = 1, + max_length: int = 77, # not sure what default to put here, always pass one? + truncate: bool = True, +): + if truncate: + # normal short prompt 77 tokens max + prompt_embeds = text_encoder( + tokens.to(text_encoder.device), output_hidden_states=True + ) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer + else: + # handle long prompts + prompt_embeds_list = [] + tokens = tokens.to(text_encoder.device) + pooled_prompt_embeds = None + for i in range(0, tokens.shape[-1], max_length): + # todo run it through the in a single batch + section_tokens = tokens[:, i: i + max_length] + embeds = text_encoder(section_tokens, output_hidden_states=True) + pooled_prompt_embed = embeds[0] + if pooled_prompt_embeds is None: + # we only want the first ( I think??) + pooled_prompt_embeds = pooled_prompt_embed + prompt_embed = embeds.hidden_states[-2] # always penultimate layer + prompt_embeds_list.append(prompt_embed) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=1) + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, pooled_prompt_embeds + + +def encode_prompts_xl( + tokenizers: list['CLIPTokenizer'], + text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']], + prompts: list[str], + prompts2: Union[list[str], None], + num_images_per_prompt: int = 1, + use_text_encoder_1: bool = True, # sdxl + use_text_encoder_2: bool = True, # sdxl + truncate: bool = True, + max_length=None, + dropout_prob=0.0, +) -> tuple[torch.FloatTensor, torch.FloatTensor]: + # text_encoder and text_encoder_2's penuultimate layer's output + text_embeds_list = [] + pooled_text_embeds = None # always 
text_encoder_2's pool + if prompts2 is None: + prompts2 = prompts + + for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)): + # todo, we are using a blank string to ignore that encoder for now. + # find a better way to do this (zeroing?, removing it from the unet?) + prompt_list_to_use = prompts if idx == 0 else prompts2 + if idx == 0 and not use_text_encoder_1: + prompt_list_to_use = ["" for _ in prompts] + if idx == 1 and not use_text_encoder_2: + prompt_list_to_use = ["" for _ in prompts] + + if dropout_prob > 0.0: + # randomly drop out prompts + prompt_list_to_use = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use + ] + + text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length) + # set the max length for the next one + if idx == 0: + max_length = text_tokens_input_ids.shape[-1] + + text_embeds, pooled_text_embeds = text_encode_xl( + text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length, + truncate=truncate + ) + + text_embeds_list.append(text_embeds) + + bs_embed = pooled_text_embeds.shape[0] + pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds + +def encode_prompts_sd3( + tokenizers: list['CLIPTokenizer'], + text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection', T5EncoderModel]], + prompts: list[str], + num_images_per_prompt: int = 1, + truncate: bool = True, + max_length=None, + dropout_prob=0.0, + pipeline = None, +): + text_embeds_list = [] + pooled_text_embeds = None # always text_encoder_2's pool + + prompt_2 = prompts + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompts + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + device = text_encoders[0].device + + prompt_embed, 
pooled_prompt_embed = pipeline._get_clip_prompt_embeds( + prompt=prompts, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = pipeline._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = pipeline._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + device=device + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + return prompt_embeds, pooled_prompt_embeds + + +# ref for long prompts https://github.com/huggingface/diffusers/issues/2136 +def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None): + if max_length is None and not truncate: + raise ValueError("max_length must be set if truncate is True") + try: + tokens = tokens.to(text_encoder.device) + except Exception as e: + print(e) + print("tokens.device", tokens.device) + print("text_encoder.device", text_encoder.device) + raise e + + if truncate: + return text_encoder(tokens)[0] + else: + # handle long prompts + prompt_embeds_list = [] + for i in range(0, tokens.shape[-1], max_length): + prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0] + prompt_embeds_list.append(prompt_embeds) + + return torch.cat(prompt_embeds_list, dim=1) + + +def encode_prompts( + tokenizer: 'CLIPTokenizer', + text_encoder: 'CLIPTextModel', + prompts: list[str], + truncate: bool = True, + max_length=None, + dropout_prob=0.0, +): + if max_length is None: + max_length = tokenizer.model_max_length + + if dropout_prob 
> 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length) + text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length) + + return text_embeddings + + +def encode_prompts_pixart( + tokenizer: 'T5Tokenizer', + text_encoder: 'T5EncoderModel', + prompts: list[str], + truncate: bool = True, + max_length=None, + dropout_prob=0.0, +): + if max_length is None: + # See Section 3.1. of the paper. + max_length = 120 + + if dropout_prob > 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + text_inputs = tokenizer( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(text_encoder.device) + + text_input_ids = text_input_ids.to(text_encoder.device) + + prompt_embeds = text_encoder(text_input_ids, attention_mask=prompt_attention_mask) + + return prompt_embeds.last_hidden_state, prompt_attention_mask + + +def encode_prompts_auraflow( + tokenizer: 'T5Tokenizer', + text_encoder: 'UMT5EncoderModel', + prompts: list[str], + truncate: bool = True, + max_length=None, + dropout_prob=0.0, +): + if max_length is None: + max_length = 256 + + if dropout_prob > 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + device 
= text_encoder.device + + text_inputs = tokenizer( + prompts, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + text_input_ids = text_inputs["input_ids"] + untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} + prompt_embeds = text_encoder(**text_inputs)[0] + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask + + return prompt_embeds, prompt_attention_mask + +def encode_prompts_flux( + tokenizer: List[Union['CLIPTokenizer','T5Tokenizer']], + text_encoder: List[Union['CLIPTextModel', 'T5EncoderModel']], + prompts: list[str], + truncate: bool = True, + max_length=None, + dropout_prob=0.0, + attn_mask: bool = False, +): + if max_length is None: + max_length = 512 + + if dropout_prob > 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + device = text_encoder[0].device + dtype = text_encoder[0].dtype + + batch_size = len(prompts) + + # clip + text_inputs = tokenizer[0]( + prompts, + padding="max_length", + max_length=tokenizer[0].model_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder[0](text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + pooled_prompt_embeds = prompt_embeds.pooler_output + pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device) + + # T5 + text_inputs = tokenizer[1]( + prompts, + padding="max_length", + max_length=max_length, 
+ truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = text_encoder[1].dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + if attn_mask: + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device) + + return prompt_embeds, pooled_prompt_embeds + + +# for XL +def get_add_time_ids( + height: int, + width: int, + dynamic_crops: bool = False, + dtype: torch.dtype = torch.float32, +): + if dynamic_crops: + # random float scale between 1 and 3 + random_scale = torch.rand(1).item() * 2 + 1 + original_size = (int(height * random_scale), int(width * random_scale)) + # random position + crops_coords_top_left = ( + torch.randint(0, original_size[0] - height, (1,)).item(), + torch.randint(0, original_size[1] - width, (1,)).item(), + ) + target_size = (height, width) + else: + original_size = (height, width) + crops_coords_top_left = (0, 0) + target_size = (height, width) + + # this is expected as 6 + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + # this is expected as 2816 + passed_add_embed_dim = ( + UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6 + + TEXT_ENCODER_2_PROJECTION_DIM # + 1280 + ) + if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM: + raise ValueError( + f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + +def concat_embeddings( + unconditional: torch.FloatTensor, + conditional: torch.FloatTensor, + n_imgs: int, +): + return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0) + + +def add_all_snr_to_noise_scheduler(noise_scheduler, device): + try: + if hasattr(noise_scheduler, "all_snr"): + return + # compute it + with torch.no_grad(): + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + all_snr.requires_grad = False + noise_scheduler.all_snr = all_snr.to(device) + except Exception as e: + # just move on + pass + + +def get_all_snr(noise_scheduler, device): + if hasattr(noise_scheduler, "all_snr"): + return noise_scheduler.all_snr.to(device) + # compute it + with torch.no_grad(): + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + all_snr.requires_grad = False + return all_snr.to(device) + +class LearnableSNRGamma: + """ + This is a trainer for learnable snr gamma + It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps + """ + def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'): + self.device = device + self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler + self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device)) + self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device)) + self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device)) + 
self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device)) + self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01) + self.buffer = [] + self.max_buffer_size = 20 + + def forward(self, loss, timesteps): + # do a our train loop for lsnr here and return our values detached + loss = loss.detach() + with torch.no_grad(): + loss_chunks = torch.chunk(loss, loss.shape[0], dim=0) + for loss_chunk in loss_chunks: + self.buffer.append(loss_chunk.mean().detach()) + if len(self.buffer) > self.max_buffer_size: + self.buffer.pop(0) + all_snr = get_all_snr(self.noise_scheduler, loss.device) + snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device) + base_snrs = snr.clone().detach() + snr.requires_grad = True + snr = (snr + self.offset_1) * self.scale + self.offset_2 + + gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr) + snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr + snr_adjusted_loss = loss * snr_weight + with torch.no_grad(): + target = torch.mean(torch.stack(self.buffer)).detach() + + # local_loss = torch.mean(torch.abs(snr_adjusted_loss - target)) + squared_differences = (snr_adjusted_loss - target) ** 2 + local_loss = torch.mean(squared_differences) + local_loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + + return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach() + + +def apply_learnable_snr_gos( + loss, + timesteps, + learnable_snr_trainer: LearnableSNRGamma +): + + snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps) + + snr = (snr + offset_1) * scale + offset_2 + + gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) + snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr + snr_adjusted_loss = loss * snr_weight + + return snr_adjusted_loss 
+ + +def apply_snr_weight( + loss, + timesteps, + noise_scheduler: Union['DDPMScheduler'], + gamma, + fixed=False, +): + # will get it from noise scheduler if exist or will calculate it if not + all_snr = get_all_snr(noise_scheduler, loss.device) + # step_indices = [] + # for t in timesteps: + # for i, st in enumerate(noise_scheduler.timesteps): + # if st == t: + # step_indices.append(i) + # break + # this breaks on some schedulers + # step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps] + + offset = 0 + if noise_scheduler.timesteps[0] == 1000: + offset = 1 + snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps]) + gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) + if fixed: + snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr + else: + snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) + snr_adjusted_loss = loss * snr_weight + + return snr_adjusted_loss + + +def precondition_model_outputs_flow_match(model_output, model_input, timestep_tensor, noise_scheduler): + mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0) + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + # unsqueeze if timestep is zero dim + for idx in range(model_output.shape[0]): + sigmas = noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim, + dtype=model_output.dtype, device=model_output.device) + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. 
+ out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx] + out_chunks.append(out) + return torch.cat(out_chunks, dim=0) diff --git a/ai-toolkit/toolkit/unloader.py b/ai-toolkit/toolkit/unloader.py new file mode 100644 index 0000000000000000000000000000000000000000..7b1670dc80cce0829c5960dcaac97e429717f6f1 --- /dev/null +++ b/ai-toolkit/toolkit/unloader.py @@ -0,0 +1,80 @@ +import gc +import torch +from toolkit.basic import flush +from toolkit.memory_management import MemoryManager +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from toolkit.models.base_model import BaseModel + + +class FakeTextEncoder(torch.nn.Module): + def __init__(self, device, dtype): + super().__init__() + # register a dummy parameter to avoid errors in some cases + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + self._device = device + self._dtype = dtype + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "This is a fake text encoder and should not be used for inference." + ) + return None + + @property + def device(self): + return self._device + + @property + def dtype(self): + return self._dtype + + def to(self, *args, **kwargs): + return self + + +def _detach_and_cpu(te: torch.nn.Module): + MemoryManager.detach(te) + # bypass any nopped-out .to() override and force an actual CPU move + torch.nn.Module.to(te, 'cpu') + + +def unload_text_encoder(model: "BaseModel"): + # unload the text encoder in a way that will work with all models and will not throw errors + # we need to make it appear as a text encoder module without actually having one so all + # to functions and what not will work. + + if model.text_encoder is not None: + if isinstance(model.text_encoder, list): + text_encoder_list = [] + pipe = model.pipeline + + # the pipeline stores text encoders like text_encoder, text_encoder_2, text_encoder_3, etc. 
+ if hasattr(pipe, "text_encoder"): + _detach_and_cpu(pipe.text_encoder) + te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype) + text_encoder_list.append(te) + pipe.text_encoder = te + + i = 2 + while hasattr(pipe, f"text_encoder_{i}"): + real_te = getattr(pipe, f"text_encoder_{i}") + _detach_and_cpu(real_te) + te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype) + text_encoder_list.append(te) + setattr(pipe, f"text_encoder_{i}", te) + i += 1 + model.text_encoder = text_encoder_list + else: + # only has a single text encoder + _detach_and_cpu(model.text_encoder) + model.text_encoder = FakeTextEncoder( + device=model.device_torch, + dtype=model.torch_dtype + ) + + torch.cuda.empty_cache() + gc.collect() + flush() diff --git a/ai-toolkit/toolkit/util/blended_blur_noise.py b/ai-toolkit/toolkit/util/blended_blur_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..86f46ddbfac0eded5be3449ab33123adcf6c7d34 --- /dev/null +++ b/ai-toolkit/toolkit/util/blended_blur_noise.py @@ -0,0 +1,84 @@ +import torch + +cached_multipier = None + +def get_multiplier(timesteps, num_timesteps=1000): + global cached_multipier + if cached_multipier is None: + # creates a bell curve + x = torch.arange(num_timesteps, dtype=torch.float32) + y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2) + + # Shift minimum to 0 + y_shifted = y - y.min() + + # Scale to make mean 1 + cached_multipier = y_shifted * (num_timesteps / y_shifted.sum()) + + scale_list = [] + # get the idx multiplier for each timestep + for i in range(timesteps.shape[0]): + idx = min(int(timesteps[i].item()) - 1, 0) + scale_list.append(cached_multipier[idx:idx + 1]) + + scales = torch.cat(scale_list, dim=0) + + batch_multiplier = scales.view(-1, 1, 1, 1) + + return batch_multiplier + + +def get_blended_blur_noise(latents, noise, timestep): + latent_chunks = torch.chunk(latents, latents.shape[0], dim=0) + + # timestep is 1000 to 0 + # timestep = 
timestep.to(latents.device, dtype=latents.dtype) + + # scale it so timestep 1000 is 0 and 0 is 2 + # blur_strength = value_map(timestep, 1000, 0, 0, 1.0) + # blur_strength = timestep / 500.0 + # blur_strength = blur_strength.view(-1, 1, 1, 1) + + # scale to 2.0 max + # blur_strength = get_multiplier(timestep).to( + # latents.device, dtype=latents.dtype + # ) * 2.0 + + # blur_strength = 2.0 + + blurred_latent_chunks = [] + for i in range(len(latent_chunks)): + latent_chunk = latent_chunks[i] + # get two random scalers 0.1 to 0.9 + # scaler1 = random.uniform(0.2, 0.8) + scaler1 = 0.25 + scaler2 = scaler1 + + # shrink latents by 1/4 and bring them back for blurring using interpolation + blur_latents = torch.nn.functional.interpolate( + latent_chunk, + size=(int(latents.shape[2] * scaler1), int(latents.shape[3] * scaler2)), + mode='bilinear', + align_corners=False + ) + blur_latents = torch.nn.functional.interpolate( + blur_latents, + size=(latents.shape[2], latents.shape[3]), + mode='bilinear', + align_corners=False + ) + # only the difference of the blur from ground truth + blur_latents = blur_latents - latent_chunk + blurred_latent_chunks.append(blur_latents) + + blur_latents = torch.cat(blurred_latent_chunks, dim=0) + + + # make random strength along batch 0 to 1 + blur_strength = torch.rand((latents.shape[0], 1, 1, 1), device=latents.device, dtype=latents.dtype) * 2 + + blur_latents = blur_latents * blur_strength + + noise = noise + blur_latents + return noise + \ No newline at end of file diff --git a/ai-toolkit/toolkit/util/get_model.py b/ai-toolkit/toolkit/util/get_model.py new file mode 100644 index 0000000000000000000000000000000000000000..545175ddf943ddabc06696b5070146adf4293bc4 --- /dev/null +++ b/ai-toolkit/toolkit/util/get_model.py @@ -0,0 +1,50 @@ +import os +from typing import List +from toolkit.models.base_model import BaseModel +from toolkit.stable_diffusion_model import StableDiffusion +from toolkit.config_modules import ModelConfig +from 
toolkit.paths import TOOLKIT_ROOT +import importlib +import pkgutil + +from toolkit.models.wan21 import Wan21, Wan21I2V +from toolkit.models.cogview4 import CogView4 + +BUILT_IN_MODELS = [ + Wan21, + Wan21I2V, + CogView4, +] + + +def get_all_models() -> List[BaseModel]: + extension_folders = ['extensions', 'extensions_built_in'] + + # This will hold the classes from all extension modules + all_model_classes: List[BaseModel] = BUILT_IN_MODELS + + # Iterate over all directories (i.e., packages) in the "extensions" directory + for sub_dir in extension_folders: + extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir) + for (_, name, _) in pkgutil.iter_modules([extensions_dir]): + try: + # Import the module + module = importlib.import_module(f"{sub_dir}.{name}") + # Get the value of the AI_TOOLKIT_MODELS variable + models = getattr(module, "AI_TOOLKIT_MODELS", None) + # Check if the value is a list + if isinstance(models, list): + # Iterate over the list and add the classes to the main list + all_model_classes.extend(models) + except ImportError as e: + print(f"Failed to import the {name} module. Error: {str(e)}") + return all_model_classes + + +def get_model_class(config: ModelConfig): + all_models = get_all_models() + for ModelClass in all_models: + if ModelClass.arch == config.arch: + return ModelClass + # default to the legacy model + return StableDiffusion diff --git a/ai-toolkit/toolkit/util/inverse_cfg.py b/ai-toolkit/toolkit/util/inverse_cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..0c85544a95c1a81cd5f7f6e4cf9ca3408e92c81e --- /dev/null +++ b/ai-toolkit/toolkit/util/inverse_cfg.py @@ -0,0 +1,25 @@ +import torch + + +def inverse_classifier_guidance( + noise_pred_cond: torch.Tensor, + noise_pred_uncond: torch.Tensor, + guidance_scale: torch.Tensor +): + """ + Adjust the noise_pred_cond for the classifier free guidance algorithm + to ensure that the final noise prediction equals the original noise_pred_cond. 
+ """ + # To make noise_pred equal noise_pred_cond_orig, we adjust noise_pred_cond + # based on the formula used in the algorithm. + # We derive the formula to find the correct adjustment for noise_pred_cond: + # noise_pred_cond = (noise_pred_cond_orig - noise_pred_uncond * guidance_scale) / (guidance_scale - 1) + # It's important to check if guidance_scale is not 1 to avoid division by zero. + if guidance_scale == 1: + # If guidance_scale is 1, adjusting is not needed or possible in the same way, + # since it would lead to division by zero. This also means the algorithm inherently + # doesn't alter the noise_pred_cond in relation to noise_pred_uncond. + # Thus, we return the original values, though this situation might need special handling. + return noise_pred_cond + adjusted_noise_pred_cond = (noise_pred_cond - noise_pred_uncond) / guidance_scale + return adjusted_noise_pred_cond diff --git a/ai-toolkit/toolkit/util/ip_adapter_utils.py b/ai-toolkit/toolkit/util/ip_adapter_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8e80643a189fdb187200816d7b9c2d7b86cb8395 --- /dev/null +++ b/ai-toolkit/toolkit/util/ip_adapter_utils.py @@ -0,0 +1,634 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. 
+ """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor 
for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear( + cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear( + cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + 
encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add 
support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear( + cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear( + cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, 
end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + ip_value = ip_value.view( + batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + with torch.no_grad(): + self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + # print(self.attn_map.shape) + + ip_hidden_states = ip_hidden_states.transpose( + 1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, 
-2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +# for controlnet +class CNAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __init__(self, num_tokens=4): + self.num_tokens = num_tokens + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + # only use text + encoder_hidden_states = encoder_hidden_states[:, :end_pos] + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = 
attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CNAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, num_tokens=4): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.num_tokens = num_tokens + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + # only 
use text + encoder_hidden_states = encoder_hidden_states[:, :end_pos] + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, 
self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens diff --git a/ai-toolkit/toolkit/util/losses.py b/ai-toolkit/toolkit/util/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..5328674adef0077683f96ab5f758c0744e205b73 --- /dev/null +++ b/ai-toolkit/toolkit/util/losses.py @@ -0,0 +1,93 @@ +import torch + + +_dwt = None + + +def _get_wavelet_loss(device, dtype): + global _dwt + if _dwt is not None: + return _dwt + + # init wavelets + from pytorch_wavelets import DWTForward + + # wave='db1' wave='haar' + dwt = DWTForward(J=1, mode="zero", wave="haar").to(device=device, dtype=dtype) + _dwt = dwt + return dwt + + +def wavelet_loss(model_pred, latents, noise): + model_pred = model_pred.float() + latents = latents.float() + noise = noise.float() + dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype) + with torch.no_grad(): + model_input_xll, model_input_xh = dwt(latents) + model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind( + model_input_xh[0], dim=2 + ) + model_input = torch.cat( + [model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1 + ) + + # reverse the noise to get the model prediction of the pure latents + model_pred = noise - model_pred + + model_pred_xll, model_pred_xh = dwt(model_pred) + model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind( + model_pred_xh[0], dim=2 + ) + model_pred = torch.cat( + [model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1 + ) + + return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none") + + +def stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler): + # this steps the on a 20 step timescale from the current step (50 idx steps ahead) + # and then reconstructs the original image at that timestep. 
This should lessen the error + # possible in high noise timesteps and make the flow smoother. + bs = model_pred.shape[0] + + noise_pred_chunks = torch.chunk(model_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + noise_chunks = torch.chunk(noise, bs) + + x0_pred_chunks = [] + + for idx in range(bs): + model_output = noise_pred_chunks[idx] # predicted noise (same shape as latent) + timestep = timestep_chunks[idx] # scalar tensor per sample (e.g., [t]) + sample = noisy_latent_chunks[idx].to(torch.float32) + noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device) + + # Initialize scheduler step index for this sample + scheduler._step_index = None + scheduler._init_step_index(timestep) + + # ---- Step +50 indices (or to the end) in sigma-space ---- + sigma = scheduler.sigmas[scheduler.step_index] + target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1) + sigma_next = scheduler.sigmas[target_idx] + + # One-step update along the model-predicted direction + stepped = sample + (sigma_next - sigma) * model_output + + # ---- Inverse-Gaussian recovery at the target timestep ---- + t_01 = ( + (scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype) + ) + original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01) + x0_pred_chunks.append(original_samples) + + predicted_images = torch.cat(x0_pred_chunks, dim=0) + + return torch.nn.functional.mse_loss( + predicted_images.float(), + latents.float().to(device=predicted_images.device), + reduction="none", + ) diff --git a/ai-toolkit/toolkit/util/mask.py b/ai-toolkit/toolkit/util/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..1b80f1caaadab19254789a50ca740dbde1eba7fc --- /dev/null +++ b/ai-toolkit/toolkit/util/mask.py @@ -0,0 +1,288 @@ +import torch +import numpy as np +import os +import torch.nn.functional as F +from PIL import Image +import time +import random + + +def generate_random_mask( + batch_size, + 
height=256, + width=256, + device='cuda', + min_coverage=0.2, + max_coverage=0.8, + num_blobs_range=(1, 3) +): + """ + Generate random blob masks for a batch of images. + Fast GPU version with smooth, non-circular blob shapes. + + Args: + batch_size (int): Number of masks to generate + height (int): Height of the mask + width (int): Width of the mask + device (str): Device to run the computation on ('cuda' or 'cpu') + min_coverage (float): Minimum percentage of the image to be covered (0-1) + max_coverage (float): Maximum percentage of the image to be covered (0-1) + num_blobs_range (tuple): Range of number of blobs (min, max) + + Returns: + torch.Tensor: Binary masks with shape (batch_size, 1, height, width) + """ + # Initialize masks on GPU + masks = torch.zeros((batch_size, 1, height, width), device=device) + + # Pre-compute coordinate grid on GPU + y_indices = torch.arange(height, device=device).view( + height, 1).expand(height, width) + x_indices = torch.arange(width, device=device).view( + 1, width).expand(height, width) + + # Prepare gaussian kernels for smoothing + small_kernel = get_gaussian_kernel(7, 1.0).to(device) + small_kernel = small_kernel.view(1, 1, 7, 7) + + large_kernel = get_gaussian_kernel(15, 2.5).to(device) + large_kernel = large_kernel.view(1, 1, 15, 15) + + # Constants + max_radius = min(height, width) // 3 + min_radius = min(height, width) // 8 + + # For each mask in the batch + for b in range(batch_size): + # Determine number of blobs for this mask + num_blobs = np.random.randint( + num_blobs_range[0], num_blobs_range[1] + 1) + + # Target coverage for this mask + target_coverage = np.random.uniform(min_coverage, max_coverage) + + # Initialize this mask + mask = torch.zeros(1, 1, height, width, device=device) + + # Generate blobs with smoother edges + for _ in range(num_blobs): + # Create a low-frequency noise field first (for smooth organic shapes) + noise_field = torch.zeros(height, width, device=device) + + # Use low-frequency sine 
waves to create base shape distortion + # This creates smoother warping compared to pure random noise + num_waves = np.random.randint(2, 5) + for i in range(num_waves): + freq_x = np.random.uniform(1.0, 3.0) * np.pi / width + freq_y = np.random.uniform(1.0, 3.0) * np.pi / height + phase_x = np.random.uniform(0, 2 * np.pi) + phase_y = np.random.uniform(0, 2 * np.pi) + amp = np.random.uniform(0.5, 1.0) * max_radius / (i+1.5) + + # Generate smooth wave patterns + wave = torch.sin(x_indices * freq_x + phase_x) * \ + torch.sin(y_indices * freq_y + phase_y) * amp + noise_field += wave + + # Basic ellipse parameters + center_y = np.random.randint(height//4, 3*height//4) + center_x = np.random.randint(width//4, 3*width//4) + radius = np.random.randint(min_radius, max_radius) + + # Squeeze and stretch the ellipse with random scaling + scale_y = np.random.uniform(0.6, 1.4) + scale_x = np.random.uniform(0.6, 1.4) + + # Random rotation + theta = np.random.uniform(0, 2 * np.pi) + cos_theta, sin_theta = np.cos(theta), np.sin(theta) + + # Calculate elliptical distance field + y_scaled = (y_indices - center_y) * scale_y + x_scaled = (x_indices - center_x) * scale_x + + # Apply rotation + rotated_y = y_scaled * cos_theta - x_scaled * sin_theta + rotated_x = y_scaled * sin_theta + x_scaled * cos_theta + + # Compute distances + distances = torch.sqrt(rotated_y**2 + rotated_x**2) + + # Apply the smooth noise field to the distance field + perturbed_distances = distances + noise_field + + # Create base blob + blob = (perturbed_distances < radius).float( + ).unsqueeze(0).unsqueeze(0) + + # Apply strong smoothing for very smooth edges + # Double smoothing to get really organic edges + blob = F.pad(blob, (7, 7, 7, 7), mode='reflect') + blob = F.conv2d(blob, large_kernel, padding=0) + + # Apply threshold to get a nice shape + rand_threshold = np.random.uniform(0.3, 0.6) + blob = (blob > rand_threshold).float() + + # Apply second smoothing pass + blob = F.pad(blob, (3, 3, 3, 3), 
mode='reflect') + blob = F.conv2d(blob, small_kernel, padding=0) + blob = (blob > 0.5).float() + + # Add to mask + mask = torch.maximum(mask, blob) + + # Ensure desired coverage + current_coverage = mask.mean().item() + + # Scale if needed to match target coverage + if current_coverage > 0: # Avoid division by zero + if current_coverage < target_coverage * 0.7: # Too small + # Dilate mask to increase coverage + mask = F.pad(mask, (2, 2, 2, 2), mode='reflect') + mask = F.max_pool2d(mask, kernel_size=5, stride=1, padding=0) + elif current_coverage > target_coverage * 1.3: # Too large + # Erode mask to decrease coverage + mask = F.pad(mask, (1, 1, 1, 1), mode='reflect') + mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=0) + mask = (mask > 0.7).float() + + # Final smooth and threshold + mask = F.pad(mask, (3, 3, 3, 3), mode='reflect') + mask = F.conv2d(mask, small_kernel, padding=0) + mask = (mask > 0.5).float() + + # Add to batch + masks[b] = mask + + return masks + + +def get_gaussian_kernel(kernel_size=5, sigma=1.0): + """ + Returns a 2D Gaussian kernel. + """ + # Create 1D kernels + x = torch.linspace(-sigma * 2, sigma * 2, kernel_size) + x = x.view(1, -1).repeat(kernel_size, 1) + y = x.transpose(0, 1) + + # 2D Gaussian + gaussian = torch.exp(-(x**2 + y**2) / (2 * sigma**2)) + gaussian /= gaussian.sum() + + return gaussian + + +def save_masks_as_images(masks, suffix="", output_dir="output"): + """ + Save generated masks as RGB JPG images using PIL. 
+ """ + os.makedirs(output_dir, exist_ok=True) + + batch_size = masks.shape[0] + for i in range(batch_size): + # Convert mask to numpy array + mask = masks[i, 0].cpu().numpy() + + # Scale to 0-255 range and convert to uint8 + mask_255 = (mask * 255).astype(np.uint8) + + # Create RGB image (white mask on black background) + rgb_mask = np.stack([mask_255, mask_255, mask_255], axis=2) + + # Convert to PIL Image and save + img = Image.fromarray(rgb_mask) + img.save(os.path.join(output_dir, f"mask_{i:03d}{suffix}.jpg"), quality=95) + + +def random_dialate_mask(mask, max_percent=0.05): + """ + Randomly dialates a binary mask with a kernel of random size. + + Args: + mask (torch.Tensor): Input mask of shape [batch_size, channels, height, width] + max_percent (float): Maximum kernel size as a percentage of the mask size + + Returns: + torch.Tensor: Dialated mask with the same shape as input + """ + + size = mask.shape[-1] + max_size = int(size * max_percent) + + # Handle case where max_size is too small + if max_size < 3: + max_size = 3 + + batch_chunks = torch.chunk(mask, mask.shape[0], dim=0) + out_chunks = [] + + for i in range(len(batch_chunks)): + chunk = batch_chunks[i] + + # Ensure kernel size is odd for proper padding + kernel_size = np.random.randint(1, max_size) + + # If kernel_size is less than 2, keep the original mask + if kernel_size < 2: + out_chunks.append(chunk) + continue + + # Make sure kernel size is odd + if kernel_size % 2 == 0: + kernel_size += 1 + + # Create normalized dilation kernel + kernel = torch.ones((1, 1, kernel_size, kernel_size), device=mask.device) / (kernel_size * kernel_size) + + # Pad the mask for convolution + padding = kernel_size // 2 + padded_mask = F.pad(chunk, (padding, padding, padding, padding), mode='constant', value=0) + + # Apply convolution + dilated = F.conv2d(padded_mask, kernel) + + # Random threshold for varied dilation effect + threshold = np.random.uniform(0.2, 0.8) + + # Apply threshold + dilated = (dilated > 
threshold).float() + + out_chunks.append(dilated) + + return torch.cat(out_chunks, dim=0) + + +if __name__ == "__main__": + # Parameters + batch_size = 20 + height = 256 + width = 256 + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + print(f"Generating {batch_size} random blob masks on {device}...") + + for i in range(5): + # time it + start = time.time() + masks = generate_random_mask( + batch_size=batch_size, + height=height, + width=width, + device=device, + min_coverage=0.2, + max_coverage=0.8, + num_blobs_range=(1, 3) + ) + dialation = random_dialate_mask(masks) + print(f"Generated {batch_size} masks with shape: {masks.shape}") + end = time.time() + # print time in milliseconds + print(f"Time taken: {(end - start)*1000:.2f} ms") + + print(f"Saving masks to 'output' directory...") + save_masks_as_images(masks) + save_masks_as_images(dialation, suffix="_dilated" ) + + print("Done!") diff --git a/ai-toolkit/toolkit/util/quantize.py b/ai-toolkit/toolkit/util/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..85a0652d679e451423b44a0dcd2fb0ede55cdb55 --- /dev/null +++ b/ai-toolkit/toolkit/util/quantize.py @@ -0,0 +1,356 @@ +from fnmatch import fnmatch +from typing import List, Optional, Union, TYPE_CHECKING +import torch + +from optimum.quanto.quantize import _quantize_submodule +from optimum.quanto.tensor import Optimizer, qtype, qtypes +from torchao.quantization.quant_api import ( + quantize_ as torchao_quantize_, + Float8WeightOnlyConfig, + UIntXWeightOnlyConfig, + Int8WeightOnlyConfig +) +from optimum.quanto import freeze +from tqdm import tqdm +from safetensors.torch import load_file +from huggingface_hub import hf_hub_download + +from toolkit.print import print_acc +import os + +if TYPE_CHECKING: + from toolkit.models.base_model import BaseModel + +# the quantize function in quanto had a bug where it was using exclude instead of include + +Q_MODULES = [ + "QLinear", + "QConv2d", + "QEmbedding", + "QBatchNorm2d", + 
"QLayerNorm", + "QConvTranspose2d", + "QEmbeddingBag", +] + +torchao_qtypes = { + # "int4": Int4WeightOnlyConfig(), + "uint2": UIntXWeightOnlyConfig(torch.uint2), + "uint3": UIntXWeightOnlyConfig(torch.uint3), + "uint4": UIntXWeightOnlyConfig(torch.uint4), + "uint5": UIntXWeightOnlyConfig(torch.uint5), + "uint6": UIntXWeightOnlyConfig(torch.uint6), + "uint7": UIntXWeightOnlyConfig(torch.uint7), + "uint8": UIntXWeightOnlyConfig(torch.uint8), + "int8": Int8WeightOnlyConfig(), + "float8": Float8WeightOnlyConfig(), +} + + +class aotype: + def __init__(self, name: str): + self.name = name + self.config = torchao_qtypes[name] + + +def get_qtype(qtype: Union[str, qtype]) -> qtype: + if qtype in torchao_qtypes: + return aotype(qtype) + if isinstance(qtype, str): + return qtypes[qtype] + else: + return qtype + + +def is_quantized_tensor(t) -> bool: + # torchao stores quantized weights as tensor subclasses (e.g. AffineQuantizedTensor) under torchao.* + # that still report as nn.Parameter and expose .dequantize(). (quanto is handled separately.) + return 'torchao' in type(t).__module__ and hasattr(t, 'dequantize') + + +def dequantize_if_quantized(t): + return t.dequantize() if is_quantized_tensor(t) else t + + +def get_torchao_config(qtype): + # returns the torchao quantization config for a given qtype string, or None if it isn't torchao + if qtype is None: + return None + try: + q = get_qtype(qtype) + except Exception: + return None + return q.config if isinstance(q, aotype) else None + + +def requantize_module_weight(module, fp_weight, orig_dtype, config) -> None: + """Write a full precision weight back into module.weight, re-quantizing in place if a torchao + config is provided so the module stays quantized (used by the continuous merge/reset method). 
+ If config is None the weight is left in full precision.""" + module.weight = torch.nn.Parameter(fp_weight.to(orig_dtype), requires_grad=False) + if config is not None: + torchao_quantize_(module, config) + + +def quantize( + model: torch.nn.Module, + weights: Optional[Union[str, qtype, aotype]] = None, + activations: Optional[Union[str, qtype]] = None, + optimizer: Optional[Optimizer] = None, + include: Optional[Union[str, List[str]]] = None, + exclude: Optional[Union[str, List[str]]] = None, +): + """Quantize the specified model submodules + + Recursively quantize the submodules of the specified parent model. + + Only modules that have quantized counterparts will be quantized. + + If include patterns are specified, the submodule name must match one of them. + + If exclude patterns are specified, the submodule must not match one of them. + + Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See + https://docs.python.org/3/library/fnmatch.html for more details. + + Note: quantization happens in-place and modifies the original model and its descendants. + + Args: + model (`torch.nn.Module`): the model whose submodules will be quantized. + weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization. + activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization. + include (`Optional[Union[str, List[str]]]`): + Patterns constituting the allowlist. If provided, module names must match at + least one pattern from the allowlist. + exclude (`Optional[Union[str, List[str]]]`): + Patterns constituting the denylist. If provided, module names must not match + any patterns from the denylist. 
+ """ + if include is not None: + include = [include] if isinstance(include, str) else include + if exclude is not None: + exclude = [exclude] if isinstance(exclude, str) else exclude + for name, m in model.named_modules(): + if include is not None and not any( + fnmatch(name, pattern) for pattern in include + ): + continue + if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude): + continue + try: + # check if m is QLinear or QConv2d + if m.__class__.__name__ in Q_MODULES: + continue + else: + if isinstance(weights, aotype): + torchao_quantize_(m, weights.config) + else: + _quantize_submodule( + model, + name, + m, + weights=weights, + activations=activations, + optimizer=optimizer, + ) + except Exception as e: + print(f"Failed to quantize {name}: {e}") + # raise e + + +def quantize_model( + base_model: "BaseModel", + model_to_quantize: torch.nn.Module, +): + from toolkit.dequantize import patch_dequantization_on_save + + if not hasattr(base_model, "get_transformer_block_names"): + raise ValueError( + "The model to quantize must have a method `get_transformer_block_names`." + ) + + # patch the state dict method + patch_dequantization_on_save(model_to_quantize) + + # sensitive modules to keep in full precision (fnmatch patterns) + exclude_modules = base_model.get_quantization_exclude_modules() or [] + + if base_model.model_config.accuracy_recovery_adapter is not None: + from toolkit.config_modules import NetworkConfig + from toolkit.lora_special import LoRASpecialNetwork + + # we need to load and quantize with an accuracy recovery adapter + # todo handle hf repos + load_lora_path = base_model.model_config.accuracy_recovery_adapter + + if not os.path.exists(load_lora_path): + # not local file, grab from the hub + + path_split = load_lora_path.split("/") + if len(path_split) > 3: + raise ValueError( + "The accuracy recovery adapter path must be a local path or for a hf repo, 'username/repo_name/filename.safetensors'." 
+ ) + repo_id = f"{path_split[0]}/{path_split[1]}" + print_acc(f"Grabbing lora from the hub: {load_lora_path}") + new_lora_path = hf_hub_download( + repo_id, + filename=path_split[-1], + ) + # replace the path + load_lora_path = new_lora_path + + # build the lora config based on the lora weights + lora_state_dict = load_file(load_lora_path) + + if hasattr(base_model, "convert_lora_weights_before_load"): + lora_state_dict = base_model.convert_lora_weights_before_load(lora_state_dict) + + network_config = { + "type": "lora", + "network_kwargs": {"only_if_contains": []}, + "transformer_only": False, + } + first_key = list(lora_state_dict.keys())[0] + first_weight = lora_state_dict[first_key] + # if it starts with lycoris and includes lokr + if first_key.startswith("lycoris") and any( + "lokr" in key for key in lora_state_dict.keys() + ): + network_config["type"] = "lokr" + + network_kwargs = {} + + # find first lora_A weight + if network_config["type"] == "lora": + linear_dim = None + for key, value in lora_state_dict.items(): + if "lora_A" in key: + linear_dim = int(value.shape[0]) + break + linear_alpha = linear_dim + network_config["linear"] = linear_dim + network_config["linear_alpha"] = linear_alpha + + # we build the keys to match every key + only_if_contains = [] + for key in lora_state_dict.keys(): + contains_key = key.split(".lora_")[0] + if contains_key not in only_if_contains: + only_if_contains.append(contains_key) + + network_kwargs["only_if_contains"] = only_if_contains + elif network_config["type"] == "lokr": + # find the factor + largest_factor = 0 + for key, value in lora_state_dict.items(): + if "lokr_w1" in key: + factor = int(value.shape[0]) + if factor > largest_factor: + largest_factor = factor + network_config["lokr_full_rank"] = True + network_config["lokr_factor"] = largest_factor + + only_if_contains = [] + for key in lora_state_dict.keys(): + if "lokr_w1" in key: + contains_key = key.split(".lokr_w1")[0] + contains_key = 
contains_key.replace("lycoris_", "") + if contains_key not in only_if_contains: + only_if_contains.append(contains_key) + network_kwargs["only_if_contains"] = only_if_contains + + if hasattr(base_model, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = base_model.target_lora_modules + + # todo auto grab these + # get dim and scale + network_config = NetworkConfig(**network_config) + + network = LoRASpecialNetwork( + text_encoder=None, + unet=model_to_quantize, + lora_dim=network_config.linear, + multiplier=1.0, + alpha=network_config.linear_alpha, + # conv_lora_dim=self.network_config.conv, + # conv_alpha=self.network_config.conv_alpha, + train_unet=True, + train_text_encoder=False, + network_config=network_config, + network_type=network_config.type, + transformer_only=network_config.transformer_only, + is_transformer=base_model.is_transformer, + base_model=base_model, + is_ara=True, + **network_kwargs + ) + network.apply_to( + None, model_to_quantize, apply_text_encoder=False, apply_unet=True + ) + network.force_to(base_model.device_torch, dtype=base_model.torch_dtype) + network._update_torch_multiplier() + network.load_weights(lora_state_dict) + network.eval() + network.is_active = True + network.can_merge_in = False + base_model.accuracy_recovery_adapter = network + + # quantize it + lora_exclude_modules = [] + quantization_type = get_qtype(base_model.model_config.qtype) + for lora_module in tqdm(network.unet_loras, desc="Attaching quantization"): + # the lora has already hijacked the original module + orig_module = lora_module.org_module[0] + orig_module.to(base_model.torch_dtype) + # make the params not require gradients + for param in orig_module.parameters(): + param.requires_grad = False + quantize(orig_module, weights=quantization_type) + freeze(orig_module) + module_name = lora_module.lora_name.replace('$$', '.').replace('transformer.', '') + lora_exclude_modules.append(module_name) + if base_model.model_config.low_vram: + # move it back to 
cpu + orig_module.to("cpu") + pass + # quantize additional layers + print_acc(" - quantizing additional layers") + quantization_type = get_qtype('uint8') + quantize( + model_to_quantize, + weights=quantization_type, + exclude=lora_exclude_modules + exclude_modules + ) + else: + # quantize model the original way without an accuracy recovery adapter + # move and quantize only certain pieces at a time. + quantization_type = get_qtype(base_model.model_config.qtype) + # all_blocks = list(model_to_quantize.transformer_blocks) + all_blocks: List[torch.nn.Module] = [] + transformer_block_names = base_model.get_transformer_block_names() + for name in transformer_block_names: + # name may be a dotted path for models that nest their blocks + # (e.g. hidream_o1's "model.language_model.layers"). + block_list = model_to_quantize + for part in name.split('.'): + block_list = getattr(block_list, part, None) + if block_list is None: + break + if block_list is not None: + all_blocks += list(block_list) + base_model.print_and_status_update( + f" - quantizing {len(all_blocks)} transformer blocks" + ) + for block in tqdm(all_blocks): + block.to(base_model.device_torch, dtype=base_model.torch_dtype, non_blocking=True) + quantize(block, weights=quantization_type) + freeze(block) + block.to("cpu", non_blocking=True) + + # todo, on extras find a universal way to quantize them on device and move them back to their original + # device without having to move the transformer blocks to the device first + base_model.print_and_status_update(" - quantizing extras") + # model_to_quantize.to(base_model.device_torch, dtype=base_model.torch_dtype) + quantize(model_to_quantize, weights=quantization_type, exclude=exclude_modules) + freeze(model_to_quantize) diff --git a/ai-toolkit/toolkit/util/shuffle.py b/ai-toolkit/toolkit/util/shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..940dee17813298e004821c1a342ff73d10e25c9b --- /dev/null +++ b/ai-toolkit/toolkit/util/shuffle.py 
@@ -0,0 +1,54 @@ +import torch +import random +import numpy as np + +def shuffle_tensor_along_axis(tensor, axis=0, seed=None): + """ + Shuffle a tensor along a specified axis without affecting the global random state. + + Args: + tensor (torch.Tensor): The input tensor to shuffle + axis (int, optional): The axis along which to shuffle. Defaults to 0. + seed (int, optional): Random seed for reproducibility. Defaults to None. + + Returns: + torch.Tensor: The shuffled tensor + """ + # Clone the tensor to avoid in-place modifications + shuffled_tensor = tensor.clone() + + # Store original random states + torch_state = torch.get_rng_state() + np_state = np.random.get_state() + py_state = random.getstate() + + try: + # Set seed if provided + if seed is not None: + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + # Get the size of the dimension to shuffle + dim_size = tensor.shape[axis] + + # Generate random indices for shuffling + indices = torch.randperm(dim_size) + + # Create a slice object to shuffle along the specified axis + slices = [slice(None)] * tensor.dim() + slices[axis] = indices + + # Apply the shuffle + shuffled_tensor = tensor[slices] + + except Exception as e: + raise RuntimeError(f"Error during shuffling: {e}") + + finally: + # Restore original random states + torch.set_rng_state(torch_state) + np.random.set_state(np_state) + random.setstate(py_state) + + return shuffled_tensor \ No newline at end of file diff --git a/ai-toolkit/toolkit/util/vae.py b/ai-toolkit/toolkit/util/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7c405275ee1abf44a51d1f3cb599cd0c55d4cc --- /dev/null +++ b/ai-toolkit/toolkit/util/vae.py @@ -0,0 +1,20 @@ +from diffusers import AutoencoderKL + + +def load_vae(vae_path, dtype): + try: + vae = AutoencoderKL.from_pretrained( + vae_path, + torch_dtype=dtype, + ) + except Exception as e: + try: + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae", + torch_dtype=dtype, + ) 
+ except Exception as e: + raise ValueError(f"Failed to load VAE from {vae_path}: {e}") + vae.to(dtype) + return vae diff --git a/ai-toolkit/ui/.gitignore b/ai-toolkit/ui/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..421d02ee01be72e67d366fb62597e1d928bf46bb --- /dev/null +++ b/ai-toolkit/ui/.gitignore @@ -0,0 +1,42 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.* +.yarn/* +!.yarn/patches +!.yarn/plugins +!.yarn/releases +!.yarn/versions + +# testing +/coverage + +# next.js +/.next/ +/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.pnpm-debug.log* + +# env files (can opt-in for committing if needed) +.env* + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts +aitk_db.db diff --git a/ai-toolkit/ui/README.md b/ai-toolkit/ui/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e215bc4ccf138bbc38ad58ad57e92135484b3c0f --- /dev/null +++ b/ai-toolkit/ui/README.md @@ -0,0 +1,36 @@ +This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app). + +## Getting Started + +First, run the development server: + +```bash +npm run dev +# or +yarn dev +# or +pnpm dev +# or +bun dev +``` + +Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. + +You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file. + +This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel. + +## Learn More + +To learn more about Next.js, take a look at the following resources: + +- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API. 
+- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial. + +You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome! + +## Deploy on Vercel + +The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js. + +Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details. diff --git a/ai-toolkit/ui/cron/actions/processQueue.ts b/ai-toolkit/ui/cron/actions/processQueue.ts new file mode 100644 index 0000000000000000000000000000000000000000..175f613ed6d4c22baa86cae90f12e315c204d26f --- /dev/null +++ b/ai-toolkit/ui/cron/actions/processQueue.ts @@ -0,0 +1,71 @@ +import prisma from '../prisma'; + +import { Job, Queue } from '@prisma/client'; +import startJob from './startJob'; + +export default async function processQueue() { + const queues: Queue[] = await prisma.queue.findMany({ + orderBy: { + id: 'asc', + }, + }); + + for (const queue of queues) { + if (!queue.is_running) { + // stop any running jobs first + const runningJobs: Job[] = await prisma.job.findMany({ + where: { + status: 'running', + gpu_ids: queue.gpu_ids, + }, + }); + + for (const job of runningJobs) { + console.log(`Stopping job ${job.id} on GPU(s) ${job.gpu_ids}`); + await prisma.job.update({ + where: { id: job.id }, + data: { + return_to_queue: true, + info: 'Stopping job...', + }, + }); + } + } + if (queue.is_running) { + // first see if one is already running, status of running or stopping + const runningJob: Job | null = await prisma.job.findFirst({ + where: { + status: { in: ['running', 'stopping'] }, + gpu_ids: queue.gpu_ids, + }, + }); + + if (runningJob) { + // already running, nothing to do + continue; // skip to next queue + } else { + // find the next job in 
the queue + const nextJob: Job | null = await prisma.job.findFirst({ + where: { + status: 'queued', + gpu_ids: queue.gpu_ids, + }, + orderBy: { + queue_position: 'asc', + }, + }); + if (nextJob) { + console.log(`Starting job ${nextJob.id} on GPU(s) ${nextJob.gpu_ids}`); + await startJob(nextJob.id); + } else { + // no more jobs, stop the queue + console.log(`No more jobs in queue for GPU(s) ${queue.gpu_ids}, stopping queue`); + await prisma.queue.update({ + where: { id: queue.id }, + data: { is_running: false }, + }); + } + } + } + } +} diff --git a/ai-toolkit/ui/cron/actions/startJob.ts b/ai-toolkit/ui/cron/actions/startJob.ts new file mode 100644 index 0000000000000000000000000000000000000000..8f0eb8bd655bbab7979d1756170a55555be11283 --- /dev/null +++ b/ai-toolkit/ui/cron/actions/startJob.ts @@ -0,0 +1,173 @@ +import prisma from '../prisma'; +import { Job } from '@prisma/client'; +import { spawn } from 'child_process'; +import path from 'path'; +import fs from 'fs'; +import { TOOLKIT_ROOT, getTrainingFolder, getHFToken } from '../paths'; +import { resolvePythonPath } from '../pythonPath'; +const isWindows = process.platform === 'win32'; + +const startAndWatchJob = (job: Job) => { + // starts and watches the job asynchronously + return new Promise(async (resolve, reject) => { + const jobID = job.id; + + // setup the training + const trainingRoot = await getTrainingFolder(); + + const trainingFolder = path.join(trainingRoot, job.name); + if (!fs.existsSync(trainingFolder)) { + fs.mkdirSync(trainingFolder, { recursive: true }); + } + + // make the config file + const configPath = path.join(trainingFolder, '.job_config.json'); + + //log to path + const logPath = path.join(trainingFolder, 'log.txt'); + + try { + // if the log path exists, move it to a folder called logs and rename it {num}_log.txt, looking for the highest num + // if the log path does not exist, create it + if (fs.existsSync(logPath)) { + const logsFolder = path.join(trainingFolder, 'logs'); + if 
(!fs.existsSync(logsFolder)) { + fs.mkdirSync(logsFolder, { recursive: true }); + } + + let num = 0; + while (fs.existsSync(path.join(logsFolder, `${num}_log.txt`))) { + num++; + } + + fs.renameSync(logPath, path.join(logsFolder, `${num}_log.txt`)); + } + } catch (e) { + console.error('Error moving log file:', e); + } + + // update the config dataset path + const jobConfig = JSON.parse(job.job_config); + jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db'); + + // write the config file + fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2)); + + const pythonPath = resolvePythonPath(); + + const runFilePath = path.join(TOOLKIT_ROOT, 'run.py'); + if (!fs.existsSync(runFilePath)) { + console.error(`run.py not found at path: ${runFilePath}`); + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'error', + info: `Error launching job: run.py not found`, + }, + }); + return; + } + + const additionalEnv: any = { + AITK_JOB_ID: jobID, + CUDA_DEVICE_ORDER: 'PCI_BUS_ID', + CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`, + IS_AI_TOOLKIT_UI: '1', + }; + + // HF_TOKEN + const hfToken = await getHFToken(); + if (hfToken && hfToken.trim() !== '') { + additionalEnv.HF_TOKEN = hfToken; + } + + // Add the --log argument to the command + const args = [runFilePath, configPath, '--log', logPath]; + + try { + let subprocess; + + if (isWindows) { + // Spawn Python directly on Windows so the process can survive parent exit + subprocess = spawn(pythonPath, args, { + env: { + ...process.env, + ...additionalEnv, + }, + cwd: TOOLKIT_ROOT, + detached: true, + windowsHide: true, + stdio: 'ignore', // don't tie stdio to parent + }); + } else { + // For non-Windows platforms, fully detach and ignore stdio so it survives daemon-like + subprocess = spawn(pythonPath, args, { + detached: true, + stdio: 'ignore', + env: { + ...process.env, + ...additionalEnv, + }, + cwd: TOOLKIT_ROOT, + }); + } + + // Save the PID to the database and a file for future 
management (stop/inspect) + const pid = subprocess.pid ?? null; + if (pid != null) { + await prisma.job.update({ + where: { id: jobID }, + data: { pid }, + }); + } + try { + fs.writeFileSync(path.join(trainingFolder, 'pid.txt'), String(pid ?? ''), { flag: 'w' }); + } catch (e) { + console.error('Error writing pid file:', e); + } + + // Important: let the child run independently of this Node process. + if (subprocess.unref) { + subprocess.unref(); + } + + // (No stdout/stderr listeners — logging should go to --log handled by your Python) + // (No monitoring loop — the whole point is to let it live past this worker) + } catch (error: any) { + // Handle any exceptions during process launch + console.error('Error launching process:', error); + + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'error', + info: `Error launching job: ${error?.message || 'Unknown error'}`, + }, + }); + return; + } + // Resolve the promise immediately after starting the process + resolve(); + }); +}; + +export default async function startJob(jobID: string) { + const job: Job | null = await prisma.job.findUnique({ + where: { id: jobID }, + }); + if (!job) { + console.error(`Job with ID ${jobID} not found`); + return; + } + // update job status to 'running', this will run sync so we don't start multiple jobs. 
+ await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'running', + stop: false, + info: 'Starting job...', + }, + }); + // start and watch the job asynchronously so the cron can continue + startAndWatchJob(job); +} diff --git a/ai-toolkit/ui/cron/paths.ts b/ai-toolkit/ui/cron/paths.ts new file mode 100644 index 0000000000000000000000000000000000000000..ef28b973590daa36719125c1a7a3b6bbf05736f3 --- /dev/null +++ b/ai-toolkit/ui/cron/paths.ts @@ -0,0 +1,37 @@ +import path from 'path'; +import prisma from './prisma'; + +export const TOOLKIT_ROOT = path.resolve('@', '..', '..'); +export const defaultTrainFolder = path.join(TOOLKIT_ROOT, 'output'); +export const defaultDatasetsFolder = path.join(TOOLKIT_ROOT, 'datasets'); +export const defaultDataRoot = path.join(TOOLKIT_ROOT, 'data'); + +console.log('TOOLKIT_ROOT:', TOOLKIT_ROOT); + +export const getTrainingFolder = async () => { + const key = 'TRAINING_FOLDER'; + let row = await prisma.settings.findFirst({ + where: { + key: key, + }, + }); + let trainingRoot = defaultTrainFolder; + if (row?.value && row.value !== '') { + trainingRoot = row.value; + } + return trainingRoot as string; +}; + +export const getHFToken = async () => { + const key = 'HF_TOKEN'; + let row = await prisma.settings.findFirst({ + where: { + key: key, + }, + }); + let token = ''; + if (row?.value && row.value !== '') { + token = row.value; + } + return token; +}; diff --git a/ai-toolkit/ui/cron/prisma.ts b/ai-toolkit/ui/cron/prisma.ts new file mode 100644 index 0000000000000000000000000000000000000000..56d96d4dc640c883cd992b035339086ec6659d8f --- /dev/null +++ b/ai-toolkit/ui/cron/prisma.ts @@ -0,0 +1,4 @@ +import { PrismaClient } from '@prisma/client'; +const prisma = new PrismaClient(); + +export default prisma; diff --git a/ai-toolkit/ui/cron/pythonPath.ts b/ai-toolkit/ui/cron/pythonPath.ts new file mode 100644 index 0000000000000000000000000000000000000000..64ac231abf64a37c58f21a28a5f224fd924cec0b --- /dev/null +++ 
b/ai-toolkit/ui/cron/pythonPath.ts @@ -0,0 +1,27 @@ +import path from 'path'; +import fs from 'fs'; +import { TOOLKIT_ROOT } from './paths'; + +const isWindows = process.platform === 'win32'; + +// Shared resolver used by both the cron worker and Next.js API routes +// so the Python interpreter is configured in exactly one place. +export const resolvePythonPath = (): string => { + const candidates: string[] = []; + + if (isWindows) { + candidates.push(path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe')); + candidates.push(path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe')); + } else { + candidates.push(path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python')); + candidates.push(path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python')); + } + + for (const candidate of candidates) { + if (fs.existsSync(candidate)) { + return candidate; + } + } + + return isWindows ? 'python.exe' : 'python3'; +}; diff --git a/ai-toolkit/ui/cron/worker.ts b/ai-toolkit/ui/cron/worker.ts new file mode 100644 index 0000000000000000000000000000000000000000..dd1c275d9a17455e33eb1d69ef86b34de05a12bd --- /dev/null +++ b/ai-toolkit/ui/cron/worker.ts @@ -0,0 +1,34 @@ +import processQueue from './actions/processQueue'; +class CronWorker { + interval: number; + is_running: boolean; + intervalId: NodeJS.Timeout; + constructor() { + this.interval = 1000; // Default interval of 1 second + this.is_running = false; + this.intervalId = setInterval(() => { + this.run(); + }, this.interval); + } + async run() { + if (this.is_running) { + return; + } + this.is_running = true; + try { + // Loop logic here + await this.loop(); + } catch (error) { + console.error('Error in cron worker loop:', error); + } + this.is_running = false; + } + + async loop() { + await processQueue(); + } +} + +// it automatically starts the loop +const cronWorker = new CronWorker(); +console.log('Cron worker started with interval:', cronWorker.interval, 'ms'); diff --git a/ai-toolkit/ui/next.config.ts b/ai-toolkit/ui/next.config.ts 
new file mode 100644 index 0000000000000000000000000000000000000000..244f1419e21e42a96b9d3e5207d45dda2cfc58f7 --- /dev/null +++ b/ai-toolkit/ui/next.config.ts @@ -0,0 +1,43 @@ +import type { NextConfig } from 'next'; +import { readFileSync } from 'fs'; +import { join } from 'path'; + +const versionFile = readFileSync(join(__dirname, '..', 'version.py'), 'utf8'); +const versionMatch = versionFile.match(/VERSION\s*=\s*["']([^"']+)["']/); +const appVersion = versionMatch ? versionMatch[1] : 'unknown'; + +const nextConfig: NextConfig = { + env: { + NEXT_PUBLIC_APP_VERSION: appVersion, + }, + serverExternalPackages: ['macstats', 'osx-temperature-sensor'], + async rewrites() { + return [ + { + source: '/proxy-8866/:path*', + destination: 'http://localhost:8866/:path*', + }, + ]; + }, + webpack: (config, { isServer }) => { + if (isServer) { + config.externals.push('osx-temperature-sensor', 'macstats'); + } + return config; + }, + devIndicators: { + buildActivity: false, + }, + typescript: { + // Remove this. 
Build fails because of route types + ignoreBuildErrors: true, + }, + experimental: { + serverActions: { + bodySizeLimit: '100gb', + }, + middlewareClientMaxBodySize: '100gb', + }, +}; + +export default nextConfig; diff --git a/ai-toolkit/ui/package-lock.json b/ai-toolkit/ui/package-lock.json new file mode 100644 index 0000000000000000000000000000000000000000..09f1e71ad0c1f54e74531370b6131a1d9abd76e7 --- /dev/null +++ b/ai-toolkit/ui/package-lock.json @@ -0,0 +1,6901 @@ +{ + "name": "ai-toolkit-ui", + "version": "0.1.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "ai-toolkit-ui", + "version": "0.1.0", + "dependencies": { + "@headlessui/react": "^2.2.0", + "@monaco-editor/react": "^4.7.0", + "@prisma/client": "^6.3.1", + "archiver": "^7.0.1", + "axios": "^1.7.9", + "classnames": "^2.5.1", + "lucide-react": "^0.475.0", + "next": "^15.5.9", + "node-cache": "^5.1.2", + "prisma": "^6.3.1", + "react": "^19.2.0", + "react-dom": "^19.2.0", + "react-dropzone": "^14.3.5", + "react-global-hooks": "^1.3.5", + "react-icons": "^5.5.0", + "react-select": "^5.10.1", + "react-virtuoso": "^4.18.7", + "react-zoom-pan-pinch": "^4.0.3", + "sqlite3": "^5.1.7", + "systeminformation": "^5.27.11", + "uplot": "^1.6.32", + "uuid": "^11.1.0", + "yaml": "^2.7.0" + }, + "devDependencies": { + "@types/archiver": "^6.0.3", + "@types/node": "^20", + "@types/react": "^19", + "@types/react-dom": "^19", + "concurrently": "^9.1.2", + "postcss": "^8", + "prettier": "^3.5.1", + "prettier-basic": "^1.0.0", + "tailwindcss": "^3.4.1", + "ts-node-dev": "^2.0.0", + "typescript": "^5" + }, + "optionalDependencies": { + "macstats": "^4.2.0" + } + }, + "node_modules/@alcalzone/ansi-tokenize": { + "version": "0.2.5", + "resolved": "https://registry.npmjs.org/@alcalzone/ansi-tokenize/-/ansi-tokenize-0.2.5.tgz", + "integrity": "sha512-3NX/MpTdroi0aKz134A6RC2Gb2iXVECN4QaAXnvCIxxIm3C3AVB1mkUe8NaaiyvOpDfsrqWhYtj+Q6a62RrTsw==", + "license": "MIT", + "optional": true, + "dependencies": 
{ + "ansi-styles": "^6.2.1", + "is-fullwidth-code-point": "^5.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@alcalzone/ansi-tokenize/node_modules/is-fullwidth-code-point": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-5.1.0.tgz", + "integrity": "sha512-5XHYaSyiqADb4RnZ1Bdad6cPp8Toise4TzEjcOYDHZkTCbKgiUl7WTUCpNWHuxmDt91wnsZBc9xinNzopv3JMQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "get-east-asian-width": "^1.3.1" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/@alloc/quick-lru": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@alloc/quick-lru/-/quick-lru-5.2.0.tgz", + "integrity": "sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/@babel/code-frame": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", + "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", + "dependencies": { + "@babel/helper-validator-identifier": "^7.27.1", + "js-tokens": "^4.0.0", + "picocolors": "^1.1.1" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/generator": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/generator/-/generator-7.27.1.tgz", + "integrity": "sha512-UnJfnIpc/+JO0/+KRVQNGU+y5taA5vCbwN8+azkX6beii/ZF+enZJSOKo11ZSzGJjlNfJHfQtmQT8H+9TXPG2w==", + "dependencies": { + "@babel/parser": "^7.27.1", + "@babel/types": "^7.27.1", + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.25", + "jsesc": "^3.0.2" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-module-imports": { 
+ "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-module-imports/-/helper-module-imports-7.27.1.tgz", + "integrity": "sha512-0gSFWUPNXNopqtIPQvlD5WgXYI5GY2kP2cCvoT8kczjbfcfuIljTbcWrulD1CIPIX2gt1wghbDy08yE1p+/r3w==", + "dependencies": { + "@babel/traverse": "^7.27.1", + "@babel/types": "^7.27.1" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-string-parser": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-validator-identifier": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.27.1.tgz", + "integrity": "sha512-D2hP9eA+Sqx1kBZgzxZh0y1trbuU+JoDkiEwqhQ36nodYqJwyEIhPSdMNd7lOm/4io72luTPWH20Yda0xOuUow==", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/parser": { + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.27.2.tgz", + "integrity": "sha512-QYLs8299NA7WM/bZAdp+CviYYkVoYXlDW2rzliy3chxd1PQjej7JORuMJDJXJUb9g0TT+B99EwaVLKmX+sPXWw==", + "dependencies": { + "@babel/types": "^7.27.1" + }, + "bin": { + "parser": "bin/babel-parser.js" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/runtime": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.1.tgz", + "integrity": "sha512-1x3D2xEk2fRo3PAhwQwu5UubzgiVWSXTBfWpVd2Mx2AzRqJuDJCsgaDVZ7HB5iGzDW1Hl1sWN2mFyKjmR9uAog==", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/template": { + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", + "integrity": 
"sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", + "dependencies": { + "@babel/code-frame": "^7.27.1", + "@babel/parser": "^7.27.2", + "@babel/types": "^7.27.1" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/traverse": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.27.1.tgz", + "integrity": "sha512-ZCYtZciz1IWJB4U61UPu4KEaqyfj+r5T1Q5mqPo+IBpcG9kHv30Z0aD8LXPgC1trYa6rK0orRyAhqUgk4MjmEg==", + "dependencies": { + "@babel/code-frame": "^7.27.1", + "@babel/generator": "^7.27.1", + "@babel/parser": "^7.27.1", + "@babel/template": "^7.27.1", + "@babel/types": "^7.27.1", + "debug": "^4.3.1", + "globals": "^11.1.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/types": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.27.1.tgz", + "integrity": "sha512-+EzkxvLNfiUeKMgy/3luqfsCWFRXLb7U6wNQTk60tovuckwB15B191tJWvpp4HjiQWdJkCxO3Wbvc6jlk3Xb2Q==", + "dependencies": { + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.27.1" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@cspotcode/source-map-support": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz", + "integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==", + "dev": true, + "dependencies": { + "@jridgewell/trace-mapping": "0.3.9" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@cspotcode/source-map-support/node_modules/@jridgewell/trace-mapping": { + "version": "0.3.9", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz", + "integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==", + "dev": true, + "dependencies": { + "@jridgewell/resolve-uri": 
"^3.0.3", + "@jridgewell/sourcemap-codec": "^1.4.10" + } + }, + "node_modules/@emnapi/runtime": { + "version": "1.7.1", + "resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.7.1.tgz", + "integrity": "sha512-PVtJr5CmLwYAU9PZDMITZoR5iAOShYREoR45EyyLrbntV50mdePTgUn4AmOw90Ifcj+x2kRjdzr1HP3RrNiHGA==", + "license": "MIT", + "optional": true, + "dependencies": { + "tslib": "^2.4.0" + } + }, + "node_modules/@emotion/babel-plugin": { + "version": "11.13.5", + "resolved": "https://registry.npmjs.org/@emotion/babel-plugin/-/babel-plugin-11.13.5.tgz", + "integrity": "sha512-pxHCpT2ex+0q+HH91/zsdHkw/lXd468DIN2zvfvLtPKLLMo6gQj7oLObq8PhkrxOZb/gGCq03S3Z7PDhS8pduQ==", + "dependencies": { + "@babel/helper-module-imports": "^7.16.7", + "@babel/runtime": "^7.18.3", + "@emotion/hash": "^0.9.2", + "@emotion/memoize": "^0.9.0", + "@emotion/serialize": "^1.3.3", + "babel-plugin-macros": "^3.1.0", + "convert-source-map": "^1.5.0", + "escape-string-regexp": "^4.0.0", + "find-root": "^1.1.0", + "source-map": "^0.5.7", + "stylis": "4.2.0" + } + }, + "node_modules/@emotion/cache": { + "version": "11.14.0", + "resolved": "https://registry.npmjs.org/@emotion/cache/-/cache-11.14.0.tgz", + "integrity": "sha512-L/B1lc/TViYk4DcpGxtAVbx0ZyiKM5ktoIyafGkH6zg/tj+mA+NE//aPYKG0k8kCHSHVJrpLpcAlOBEXQ3SavA==", + "dependencies": { + "@emotion/memoize": "^0.9.0", + "@emotion/sheet": "^1.4.0", + "@emotion/utils": "^1.4.2", + "@emotion/weak-memoize": "^0.4.0", + "stylis": "4.2.0" + } + }, + "node_modules/@emotion/hash": { + "version": "0.9.2", + "resolved": "https://registry.npmjs.org/@emotion/hash/-/hash-0.9.2.tgz", + "integrity": "sha512-MyqliTZGuOm3+5ZRSaaBGP3USLw6+EGykkwZns2EPC5g8jJ4z9OrdZY9apkl3+UP9+sdz76YYkwCKP5gh8iY3g==" + }, + "node_modules/@emotion/memoize": { + "version": "0.9.0", + "resolved": "https://registry.npmjs.org/@emotion/memoize/-/memoize-0.9.0.tgz", + "integrity": "sha512-30FAj7/EoJ5mwVPOWhAyCX+FPfMDrVecJAM+Iw9NRoSl4BBAQeqj4cApHHUXOVvIPgLVDsCFoz/hGD+5QQD1GQ==" + }, + 
"node_modules/@emotion/react": { + "version": "11.14.0", + "resolved": "https://registry.npmjs.org/@emotion/react/-/react-11.14.0.tgz", + "integrity": "sha512-O000MLDBDdk/EohJPFUqvnp4qnHeYkVP5B0xEG0D/L7cOKP9kefu2DXn8dj74cQfsEzUqh+sr1RzFqiL1o+PpA==", + "dependencies": { + "@babel/runtime": "^7.18.3", + "@emotion/babel-plugin": "^11.13.5", + "@emotion/cache": "^11.14.0", + "@emotion/serialize": "^1.3.3", + "@emotion/use-insertion-effect-with-fallbacks": "^1.2.0", + "@emotion/utils": "^1.4.2", + "@emotion/weak-memoize": "^0.4.0", + "hoist-non-react-statics": "^3.3.1" + }, + "peerDependencies": { + "react": ">=16.8.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@emotion/serialize": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/@emotion/serialize/-/serialize-1.3.3.tgz", + "integrity": "sha512-EISGqt7sSNWHGI76hC7x1CksiXPahbxEOrC5RjmFRJTqLyEK9/9hZvBbiYn70dw4wuwMKiEMCUlR6ZXTSWQqxA==", + "dependencies": { + "@emotion/hash": "^0.9.2", + "@emotion/memoize": "^0.9.0", + "@emotion/unitless": "^0.10.0", + "@emotion/utils": "^1.4.2", + "csstype": "^3.0.2" + } + }, + "node_modules/@emotion/sheet": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@emotion/sheet/-/sheet-1.4.0.tgz", + "integrity": "sha512-fTBW9/8r2w3dXWYM4HCB1Rdp8NLibOw2+XELH5m5+AkWiL/KqYX6dc0kKYlaYyKjrQ6ds33MCdMPEwgs2z1rqg==" + }, + "node_modules/@emotion/unitless": { + "version": "0.10.0", + "resolved": "https://registry.npmjs.org/@emotion/unitless/-/unitless-0.10.0.tgz", + "integrity": "sha512-dFoMUuQA20zvtVTuxZww6OHoJYgrzfKM1t52mVySDJnMSEa08ruEvdYQbhvyu6soU+NeLVd3yKfTfT0NeV6qGg==" + }, + "node_modules/@emotion/use-insertion-effect-with-fallbacks": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@emotion/use-insertion-effect-with-fallbacks/-/use-insertion-effect-with-fallbacks-1.2.0.tgz", + "integrity": "sha512-yJMtVdH59sxi/aVJBpk9FQq+OR8ll5GT8oWd57UpeaKEVGab41JWaCFA7FRLoMLloOZF/c/wsPoe+bfGmRKgDg==", + 
"peerDependencies": { + "react": ">=16.8.0" + } + }, + "node_modules/@emotion/utils": { + "version": "1.4.2", + "resolved": "https://registry.npmjs.org/@emotion/utils/-/utils-1.4.2.tgz", + "integrity": "sha512-3vLclRofFziIa3J2wDh9jjbkUz9qk5Vi3IZ/FSTKViB0k+ef0fPV7dYrUIugbgupYDx7v9ud/SjrtEP8Y4xLoA==" + }, + "node_modules/@emotion/weak-memoize": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/@emotion/weak-memoize/-/weak-memoize-0.4.0.tgz", + "integrity": "sha512-snKqtPW01tN0ui7yu9rGv69aJXr/a/Ywvl11sUjNtEcRc+ng/mQriFL0wLXMef74iHa/EkftbDzU9F8iFbH+zg==" + }, + "node_modules/@floating-ui/core": { + "version": "1.6.9", + "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.6.9.tgz", + "integrity": "sha512-uMXCuQ3BItDUbAMhIXw7UPXRfAlOAvZzdK9BWpE60MCn+Svt3aLn9jsPTi/WNGlRUu2uI0v5S7JiIUsbsvh3fw==", + "dependencies": { + "@floating-ui/utils": "^0.2.9" + } + }, + "node_modules/@floating-ui/dom": { + "version": "1.6.13", + "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.6.13.tgz", + "integrity": "sha512-umqzocjDgNRGTuO7Q8CU32dkHkECqI8ZdMZ5Swb6QAM0t5rnlrN3lGo1hdpscRd3WS8T6DKYK4ephgIH9iRh3w==", + "dependencies": { + "@floating-ui/core": "^1.6.0", + "@floating-ui/utils": "^0.2.9" + } + }, + "node_modules/@floating-ui/react": { + "version": "0.26.28", + "resolved": "https://registry.npmjs.org/@floating-ui/react/-/react-0.26.28.tgz", + "integrity": "sha512-yORQuuAtVpiRjpMhdc0wJj06b9JFjrYF4qp96j++v2NBpbi6SEGF7donUJ3TMieerQ6qVkAv1tgr7L4r5roTqw==", + "dependencies": { + "@floating-ui/react-dom": "^2.1.2", + "@floating-ui/utils": "^0.2.8", + "tabbable": "^6.0.0" + }, + "peerDependencies": { + "react": ">=16.8.0", + "react-dom": ">=16.8.0" + } + }, + "node_modules/@floating-ui/react-dom": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.1.2.tgz", + "integrity": "sha512-06okr5cgPzMNBy+Ycse2A6udMi4bqwW/zgBF/rwjcNqWkyr82Mcg8b0vjX8OJpZFy/FKjJmw6wV7t44kK6kW7A==", + "dependencies": { + 
"@floating-ui/dom": "^1.0.0" + }, + "peerDependencies": { + "react": ">=16.8.0", + "react-dom": ">=16.8.0" + } + }, + "node_modules/@floating-ui/utils": { + "version": "0.2.9", + "resolved": "https://registry.npmjs.org/@floating-ui/utils/-/utils-0.2.9.tgz", + "integrity": "sha512-MDWhGtE+eHw5JW7lq4qhc5yRLS11ERl1c7Z6Xd0a58DozHES6EnNNwUWbMiG4J9Cgj053Bhk8zvlhFYKVhULwg==" + }, + "node_modules/@gar/promisify": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@gar/promisify/-/promisify-1.1.3.tgz", + "integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==", + "license": "MIT", + "optional": true + }, + "node_modules/@headlessui/react": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/@headlessui/react/-/react-2.2.0.tgz", + "integrity": "sha512-RzCEg+LXsuI7mHiSomsu/gBJSjpupm6A1qIZ5sWjd7JhARNlMiSA4kKfJpCKwU9tE+zMRterhhrP74PvfJrpXQ==", + "dependencies": { + "@floating-ui/react": "^0.26.16", + "@react-aria/focus": "^3.17.1", + "@react-aria/interactions": "^3.21.3", + "@tanstack/react-virtual": "^3.8.1" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "react": "^18 || ^19 || ^19.0.0-rc", + "react-dom": "^18 || ^19 || ^19.0.0-rc" + } + }, + "node_modules/@img/colour": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@img/colour/-/colour-1.0.0.tgz", + "integrity": "sha512-A5P/LfWGFSl6nsckYtjw9da+19jB8hkJ6ACTGcDfEJ0aE+l2n2El7dsVM7UVHZQ9s2lmYMWlrS21YLy2IR1LUw==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=18" + } + }, + "node_modules/@img/sharp-darwin-arm64": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-darwin-arm64/-/sharp-darwin-arm64-0.34.5.tgz", + "integrity": "sha512-imtQ3WMJXbMY4fxb/Ndp6HBTNVtWCUI0WdobyheGf5+ad6xX8VIDO8u2xE4qc/fr08CKG/7dDseFtn6M6g/r3w==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^18.17.0 || 
^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-darwin-arm64": "1.2.4" + } + }, + "node_modules/@img/sharp-darwin-x64": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-darwin-x64/-/sharp-darwin-x64-0.34.5.tgz", + "integrity": "sha512-YNEFAF/4KQ/PeW0N+r+aVVsoIY0/qxxikF2SWdp+NRkmMB7y9LBZAVqQ4yhGCm/H3H270OSykqmQMKLBhBJDEw==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-darwin-x64": "1.2.4" + } + }, + "node_modules/@img/sharp-libvips-darwin-arm64": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-darwin-arm64/-/sharp-libvips-darwin-arm64-1.2.4.tgz", + "integrity": "sha512-zqjjo7RatFfFoP0MkQ51jfuFZBnVE2pRiaydKJ1G/rHZvnsrHAOcQALIi9sA5co5xenQdTugCvtb1cuf78Vf4g==", + "cpu": [ + "arm64" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + "darwin" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-darwin-x64": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-darwin-x64/-/sharp-libvips-darwin-x64-1.2.4.tgz", + "integrity": "sha512-1IOd5xfVhlGwX+zXv2N93k0yMONvUlANylbJw1eTah8K/Jtpi15KC+WSiaX/nBmbm2HxRM1gZ0nSdjSsrZbGKg==", + "cpu": [ + "x64" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + "darwin" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-arm": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-arm/-/sharp-libvips-linux-arm-1.2.4.tgz", + "integrity": "sha512-bFI7xcKFELdiNCVov8e44Ia4u2byA+l3XtsAj+Q8tfCwO6BQ8iDojYdvoPMqsKDkuoOo+X6HZA0s0q11ANMQ8A==", 
+ "cpu": [ + "arm" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + "linux" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-arm64": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-arm64/-/sharp-libvips-linux-arm64-1.2.4.tgz", + "integrity": "sha512-excjX8DfsIcJ10x1Kzr4RcWe1edC9PquDRRPx3YVCvQv+U5p7Yin2s32ftzikXojb1PIFc/9Mt28/y+iRklkrw==", + "cpu": [ + "arm64" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + "linux" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-ppc64": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-ppc64/-/sharp-libvips-linux-ppc64-1.2.4.tgz", + "integrity": "sha512-FMuvGijLDYG6lW+b/UvyilUWu5Ayu+3r2d1S8notiGCIyYU/76eig1UfMmkZ7vwgOrzKzlQbFSuQfgm7GYUPpA==", + "cpu": [ + "ppc64" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + "linux" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-riscv64": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-riscv64/-/sharp-libvips-linux-riscv64-1.2.4.tgz", + "integrity": "sha512-oVDbcR4zUC0ce82teubSm+x6ETixtKZBh/qbREIOcI3cULzDyb18Sr/Wcyx7NRQeQzOiHTNbZFF1UwPS2scyGA==", + "cpu": [ + "riscv64" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + "linux" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-s390x": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-s390x/-/sharp-libvips-linux-s390x-1.2.4.tgz", + "integrity": "sha512-qmp9VrzgPgMoGZyPvrQHqk02uyjA0/QrTO26Tqk6l4ZV0MPWIW6LTkqOIov+J1yEu7MbFQaDpwdwJKhbJvuRxQ==", + "cpu": [ + "s390x" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + 
"linux" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-x64": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-x64/-/sharp-libvips-linux-x64-1.2.4.tgz", + "integrity": "sha512-tJxiiLsmHc9Ax1bz3oaOYBURTXGIRDODBqhveVHonrHJ9/+k89qbLl0bcJns+e4t4rvaNBxaEZsFtSfAdquPrw==", + "cpu": [ + "x64" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + "linux" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linuxmusl-arm64": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linuxmusl-arm64/-/sharp-libvips-linuxmusl-arm64-1.2.4.tgz", + "integrity": "sha512-FVQHuwx1IIuNow9QAbYUzJ+En8KcVm9Lk5+uGUQJHaZmMECZmOlix9HnH7n1TRkXMS0pGxIJokIVB9SuqZGGXw==", + "cpu": [ + "arm64" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + "linux" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linuxmusl-x64": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linuxmusl-x64/-/sharp-libvips-linuxmusl-x64-1.2.4.tgz", + "integrity": "sha512-+LpyBk7L44ZIXwz/VYfglaX/okxezESc6UxDSoyo2Ks6Jxc4Y7sGjpgU9s4PMgqgjj1gZCylTieNamqA1MF7Dg==", + "cpu": [ + "x64" + ], + "license": "LGPL-3.0-or-later", + "optional": true, + "os": [ + "linux" + ], + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-linux-arm": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-linux-arm/-/sharp-linux-arm-0.34.5.tgz", + "integrity": "sha512-9dLqsvwtg1uuXBGZKsxem9595+ujv0sJ6Vi8wcTANSFpwV/GONat5eCkzQo/1O6zRIkh0m/8+5BjrRr7jDUSZw==", + "cpu": [ + "arm" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": 
"https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-arm": "1.2.4" + } + }, + "node_modules/@img/sharp-linux-arm64": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-linux-arm64/-/sharp-linux-arm64-0.34.5.tgz", + "integrity": "sha512-bKQzaJRY/bkPOXyKx5EVup7qkaojECG6NLYswgktOZjaXecSAeCWiZwwiFf3/Y+O1HrauiE3FVsGxFg8c24rZg==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-arm64": "1.2.4" + } + }, + "node_modules/@img/sharp-linux-ppc64": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-linux-ppc64/-/sharp-linux-ppc64-0.34.5.tgz", + "integrity": "sha512-7zznwNaqW6YtsfrGGDA6BRkISKAAE1Jo0QdpNYXNMHu2+0dTrPflTLNkpc8l7MUP5M16ZJcUvysVWWrMefZquA==", + "cpu": [ + "ppc64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-ppc64": "1.2.4" + } + }, + "node_modules/@img/sharp-linux-riscv64": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-linux-riscv64/-/sharp-linux-riscv64-0.34.5.tgz", + "integrity": "sha512-51gJuLPTKa7piYPaVs8GmByo7/U7/7TZOq+cnXJIHZKavIRHAP77e3N2HEl3dgiqdD/w0yUfiJnII77PuDDFdw==", + "cpu": [ + "riscv64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-riscv64": "1.2.4" + } + }, + "node_modules/@img/sharp-linux-s390x": { + "version": "0.34.5", + "resolved": 
"https://registry.npmjs.org/@img/sharp-linux-s390x/-/sharp-linux-s390x-0.34.5.tgz", + "integrity": "sha512-nQtCk0PdKfho3eC5MrbQoigJ2gd1CgddUMkabUj+rBevs8tZ2cULOx46E7oyX+04WGfABgIwmMC0VqieTiR4jg==", + "cpu": [ + "s390x" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-s390x": "1.2.4" + } + }, + "node_modules/@img/sharp-linux-x64": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-linux-x64/-/sharp-linux-x64-0.34.5.tgz", + "integrity": "sha512-MEzd8HPKxVxVenwAa+JRPwEC7QFjoPWuS5NZnBt6B3pu7EG2Ge0id1oLHZpPJdn3OQK+BQDiw9zStiHBTJQQQQ==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-x64": "1.2.4" + } + }, + "node_modules/@img/sharp-linuxmusl-arm64": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-linuxmusl-arm64/-/sharp-linuxmusl-arm64-0.34.5.tgz", + "integrity": "sha512-fprJR6GtRsMt6Kyfq44IsChVZeGN97gTD331weR1ex1c1rypDEABN6Tm2xa1wE6lYb5DdEnk03NZPqA7Id21yg==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linuxmusl-arm64": "1.2.4" + } + }, + "node_modules/@img/sharp-linuxmusl-x64": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-linuxmusl-x64/-/sharp-linuxmusl-x64-0.34.5.tgz", + "integrity": "sha512-Jg8wNT1MUzIvhBFxViqrEhWDGzqymo3sV7z7ZsaWbZNDLXRJZoRGrjulp60YYtV4wfY8VIKcWidjojlLcWrd8Q==", + "cpu": [ + "x64" + ], + 
"license": "Apache-2.0", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linuxmusl-x64": "1.2.4" + } + }, + "node_modules/@img/sharp-wasm32": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-wasm32/-/sharp-wasm32-0.34.5.tgz", + "integrity": "sha512-OdWTEiVkY2PHwqkbBI8frFxQQFekHaSSkUIJkwzclWZe64O1X4UlUjqqqLaPbUpMOQk6FBu/HtlGXNblIs0huw==", + "cpu": [ + "wasm32" + ], + "license": "Apache-2.0 AND LGPL-3.0-or-later AND MIT", + "optional": true, + "dependencies": { + "@emnapi/runtime": "^1.7.0" + }, + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-win32-arm64": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-win32-arm64/-/sharp-win32-arm64-0.34.5.tgz", + "integrity": "sha512-WQ3AgWCWYSb2yt+IG8mnC6Jdk9Whs7O0gxphblsLvdhSpSTtmu69ZG1Gkb6NuvxsNACwiPV6cNSZNzt0KPsw7g==", + "cpu": [ + "arm64" + ], + "license": "Apache-2.0 AND LGPL-3.0-or-later", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-win32-ia32": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/@img/sharp-win32-ia32/-/sharp-win32-ia32-0.34.5.tgz", + "integrity": "sha512-FV9m/7NmeCmSHDD5j4+4pNI8Cp3aW+JvLoXcTUo0IqyjSfAZJ8dIUmijx1qaJsIiU+Hosw6xM5KijAWRJCSgNg==", + "cpu": [ + "ia32" + ], + "license": "Apache-2.0 AND LGPL-3.0-or-later", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-win32-x64": { + "version": "0.34.5", + "resolved": 
"https://registry.npmjs.org/@img/sharp-win32-x64/-/sharp-win32-x64-0.34.5.tgz", + "integrity": "sha512-+29YMsqY2/9eFEiW93eqWnuLcWcufowXewwSNIT6UwZdUUCrM3oFjMWH/Z6/TMmb4hlFenmfAVbpWeup2jryCw==", + "cpu": [ + "x64" + ], + "license": "Apache-2.0 AND LGPL-3.0-or-later", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@isaacs/cliui": { + "version": "8.0.2", + "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", + "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==", + "dependencies": { + "string-width": "^5.1.2", + "string-width-cjs": "npm:string-width@^4.2.0", + "strip-ansi": "^7.0.1", + "strip-ansi-cjs": "npm:strip-ansi@^6.0.1", + "wrap-ansi": "^8.1.0", + "wrap-ansi-cjs": "npm:wrap-ansi@^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.8", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz", + "integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==", + "dependencies": { + "@jridgewell/set-array": "^1.2.1", + "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/trace-mapping": "^0.3.24" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/set-array": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", + "integrity": 
"sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.25", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", + "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/@monaco-editor/loader": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@monaco-editor/loader/-/loader-1.5.0.tgz", + "integrity": "sha512-hKoGSM+7aAc7eRTRjpqAZucPmoNOC4UUbknb/VNoTkEIkCPhqV8LfbsgM1webRM7S/z21eHEx9Fkwx8Z/C/+Xw==", + "dependencies": { + "state-local": "^1.0.6" + } + }, + "node_modules/@monaco-editor/react": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/@monaco-editor/react/-/react-4.7.0.tgz", + "integrity": "sha512-cyzXQCtO47ydzxpQtCGSQGOC8Gk3ZUeBXFAxD+CWXYFo5OqZyZUonFl0DwUlTyAfRHntBfw2p3w4s9R6oe1eCA==", + "dependencies": { + "@monaco-editor/loader": "^1.5.0" + }, + "peerDependencies": { + "monaco-editor": ">= 0.25.0 < 1", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, + "node_modules/@next/env": { + "version": "15.5.19", + "resolved": "https://registry.npmjs.org/@next/env/-/env-15.5.19.tgz", + "integrity": "sha512-sWWluFvcv5v3Fxznmf2ZfjyoVQt/64oCnYqS90inQWGzMPK1VjvekPiz3OPHKmFT30EnHrjlbyaHLt3M0vWabw==", + "license": "MIT" + }, + "node_modules/@next/swc-darwin-arm64": { + "version": "15.5.19", + "resolved": 
"https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.5.19.tgz", + "integrity": "sha512-jx9wWlTKueHKPvVOndyr7WuaevWCkuYqsQ8gC0TMPKAVWG3MhcdMrjfo9tvIZNXd0QOUYXXvAcZ325y8Uq7uzg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-darwin-x64": { + "version": "15.5.19", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.5.19.tgz", + "integrity": "sha512-291KFcsIQ3OenRdiUDFOR6W3wezzH4auENXm1gbm1Bjd4ANMMRgxPrWTUztQN43BnVoVuMnHCrLeECIMwgFKbA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-arm64-gnu": { + "version": "15.5.19", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.5.19.tgz", + "integrity": "sha512-WeH+nelQyyMeE2f8FxBRZNrGipya5zHZV2vjzfCOAYyiI6am+NbnWAAldOBFQBB2w0DjJcsvrKqoFT2b7+5YoA==", + "cpu": [ + "arm64" + ], + "libc": [ + "glibc" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-arm64-musl": { + "version": "15.5.19", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.5.19.tgz", + "integrity": "sha512-5xTOE0lDlDCSSfp+BAif7j17VRRCjWp//ZPZy6NI0QpdrhxtQnsZguSx0xAAZ0c9XZLrLLwCe/XVe5YPrRilKw==", + "cpu": [ + "arm64" + ], + "libc": [ + "musl" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-x64-gnu": { + "version": "15.5.19", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.5.19.tgz", + "integrity": "sha512-LTxRmMgqqMv05Had879W00Fm53quiJd3Zuz8h1JSNJ3nGSlbZ/7Tjs1tKyScgN3Au3t3MyPsjPlq60fMmSHLsg==", + "cpu": [ + "x64" + ], + "libc": [ + "glibc" + ], + "license": "MIT", + 
"optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-x64-musl": { + "version": "15.5.19", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.5.19.tgz", + "integrity": "sha512-eoNQSpA5PQfB9wBO4RA47MTDXWz1fizy9Y3Z6e4DetYIF3dvjuu8sj7aIGn/bFCU6lnFzTK34NtCaffP4NsQ7Q==", + "cpu": [ + "x64" + ], + "libc": [ + "musl" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-arm64-msvc": { + "version": "15.5.19", + "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.5.19.tgz", + "integrity": "sha512-6UNt2dFuCHOe446sm/Kp69nUe8/wIhnh9bm6Xcqw4qEWCOppLMOvhTBVgvM7invVUNr4SPpP6NOQsACtn2IN9Q==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-x64-msvc": { + "version": "15.5.19", + "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.5.19.tgz", + "integrity": "sha512-PhmojAHyqMne56HBLGu9dhDnHPuFmEjrXSQMM/nW0J6j849lk3ESrVtqNJcCk8CKOV7brpTTbaYAjwKPzKM69w==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dev": true, + "dependencies": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": 
"sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "dev": true, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dev": true, + "dependencies": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@npmcli/fs": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@npmcli/fs/-/fs-1.1.1.tgz", + "integrity": "sha512-8KG5RD0GVP4ydEzRn/I4BNDuxDtqVbOdm8675T49OIG/NGhaK0pjPX7ZcDlvKYbA+ulvVK3ztfcF4uBdOxuJbQ==", + "license": "ISC", + "optional": true, + "dependencies": { + "@gar/promisify": "^1.0.1", + "semver": "^7.3.5" + } + }, + "node_modules/@npmcli/move-file": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@npmcli/move-file/-/move-file-1.1.2.tgz", + "integrity": "sha512-1SUf/Cg2GzGDyaf15aR9St9TWlb+XvbZXWpDx8YKs7MLzMH/BCeopv+y9vzrzgkfykCGuWOlSu3mZhj2+FQcrg==", + "deprecated": "This functionality has been moved to @npmcli/fs", + "license": "MIT", + "optional": true, + "dependencies": { + "mkdirp": "^1.0.4", + "rimraf": "^3.0.2" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/@pkgjs/parseargs": { + "version": "0.11.0", + "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz", + "integrity": "sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==", + "optional": true, + "engines": { + "node": ">=14" + } + }, + "node_modules/@prisma/client": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/@prisma/client/-/client-6.3.1.tgz", + "integrity": "sha512-ARAJaPs+eBkemdky/XU3cvGRl+mIPHCN2lCXsl5Vlb0E2gV+R6IN7aCI8CisRGszEZondwIsW9Iz8EJkTdykyA==", + "hasInstallScript": true, + "license": "Apache-2.0", + 
"engines": { + "node": ">=18.18" + }, + "peerDependencies": { + "prisma": "*", + "typescript": ">=5.1.0" + }, + "peerDependenciesMeta": { + "prisma": { + "optional": true + }, + "typescript": { + "optional": true + } + } + }, + "node_modules/@prisma/debug": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/@prisma/debug/-/debug-6.3.1.tgz", + "integrity": "sha512-RrEBkd+HLZx+ydfmYT0jUj7wjLiS95wfTOSQ+8FQbvb6vHh5AeKfEPt/XUQ5+Buljj8hltEfOslEW57/wQIVeA==", + "license": "Apache-2.0" + }, + "node_modules/@prisma/engines": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/@prisma/engines/-/engines-6.3.1.tgz", + "integrity": "sha512-sXdqEVLyGAJ5/iUoG/Ea5AdHMN71m6PzMBWRQnLmhhOejzqAaEr8rUd623ql6OJpED4s/U4vIn4dg1qkF7vGag==", + "hasInstallScript": true, + "license": "Apache-2.0", + "dependencies": { + "@prisma/debug": "6.3.1", + "@prisma/engines-version": "6.3.0-17.acc0b9dd43eb689cbd20c9470515d719db10d0b0", + "@prisma/fetch-engine": "6.3.1", + "@prisma/get-platform": "6.3.1" + } + }, + "node_modules/@prisma/engines-version": { + "version": "6.3.0-17.acc0b9dd43eb689cbd20c9470515d719db10d0b0", + "resolved": "https://registry.npmjs.org/@prisma/engines-version/-/engines-version-6.3.0-17.acc0b9dd43eb689cbd20c9470515d719db10d0b0.tgz", + "integrity": "sha512-R/ZcMuaWZT2UBmgX3Ko6PAV3f8//ZzsjRIG1eKqp3f2rqEqVtCv+mtzuH2rBPUC9ujJ5kCb9wwpxeyCkLcHVyA==", + "license": "Apache-2.0" + }, + "node_modules/@prisma/fetch-engine": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/@prisma/fetch-engine/-/fetch-engine-6.3.1.tgz", + "integrity": "sha512-HOf/0umOgt+/S2xtZze+FHKoxpVg4YpVxROr6g2YG09VsI3Ipyb+rGvD6QGbCqkq5NTWAAZoOGNL+oy7t+IhaQ==", + "license": "Apache-2.0", + "dependencies": { + "@prisma/debug": "6.3.1", + "@prisma/engines-version": "6.3.0-17.acc0b9dd43eb689cbd20c9470515d719db10d0b0", + "@prisma/get-platform": "6.3.1" + } + }, + "node_modules/@prisma/get-platform": { + "version": "6.3.1", + "resolved": 
"https://registry.npmjs.org/@prisma/get-platform/-/get-platform-6.3.1.tgz", + "integrity": "sha512-AYLq6Hk9xG73JdLWJ3Ip9Wg/vlP7xPvftGBalsPzKDOHr/ImhwJ09eS8xC2vNT12DlzGxhfk8BkL0ve2OriNhQ==", + "license": "Apache-2.0", + "dependencies": { + "@prisma/debug": "6.3.1" + } + }, + "node_modules/@react-aria/focus": { + "version": "3.19.1", + "resolved": "https://registry.npmjs.org/@react-aria/focus/-/focus-3.19.1.tgz", + "integrity": "sha512-bix9Bu1Ue7RPcYmjwcjhB14BMu2qzfJ3tMQLqDc9pweJA66nOw8DThy3IfVr8Z7j2PHktOLf9kcbiZpydKHqzg==", + "dependencies": { + "@react-aria/interactions": "^3.23.0", + "@react-aria/utils": "^3.27.0", + "@react-types/shared": "^3.27.0", + "@swc/helpers": "^0.5.0", + "clsx": "^2.0.0" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1", + "react-dom": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1" + } + }, + "node_modules/@react-aria/interactions": { + "version": "3.23.0", + "resolved": "https://registry.npmjs.org/@react-aria/interactions/-/interactions-3.23.0.tgz", + "integrity": "sha512-0qR1atBIWrb7FzQ+Tmr3s8uH5mQdyRH78n0krYaG8tng9+u1JlSi8DGRSaC9ezKyNB84m7vHT207xnHXGeJ3Fg==", + "dependencies": { + "@react-aria/ssr": "^3.9.7", + "@react-aria/utils": "^3.27.0", + "@react-types/shared": "^3.27.0", + "@swc/helpers": "^0.5.0" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1", + "react-dom": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1" + } + }, + "node_modules/@react-aria/ssr": { + "version": "3.9.7", + "resolved": "https://registry.npmjs.org/@react-aria/ssr/-/ssr-3.9.7.tgz", + "integrity": "sha512-GQygZaGlmYjmYM+tiNBA5C6acmiDWF52Nqd40bBp0Znk4M4hP+LTmI0lpI1BuKMw45T8RIhrAsICIfKwZvi2Gg==", + "dependencies": { + "@swc/helpers": "^0.5.0" + }, + "engines": { + "node": ">= 12" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1" + } + }, + "node_modules/@react-aria/utils": { + "version": "3.27.0", + "resolved": 
"https://registry.npmjs.org/@react-aria/utils/-/utils-3.27.0.tgz", + "integrity": "sha512-p681OtApnKOdbeN8ITfnnYqfdHS0z7GE+4l8EXlfLnr70Rp/9xicBO6d2rU+V/B3JujDw2gPWxYKEnEeh0CGCw==", + "dependencies": { + "@react-aria/ssr": "^3.9.7", + "@react-stately/utils": "^3.10.5", + "@react-types/shared": "^3.27.0", + "@swc/helpers": "^0.5.0", + "clsx": "^2.0.0" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1", + "react-dom": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1" + } + }, + "node_modules/@react-stately/utils": { + "version": "3.10.5", + "resolved": "https://registry.npmjs.org/@react-stately/utils/-/utils-3.10.5.tgz", + "integrity": "sha512-iMQSGcpaecghDIh3mZEpZfoFH3ExBwTtuBEcvZ2XnGzCgQjeYXcMdIUwAfVQLXFTdHUHGF6Gu6/dFrYsCzySBQ==", + "dependencies": { + "@swc/helpers": "^0.5.0" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1" + } + }, + "node_modules/@react-types/shared": { + "version": "3.27.0", + "resolved": "https://registry.npmjs.org/@react-types/shared/-/shared-3.27.0.tgz", + "integrity": "sha512-gvznmLhi6JPEf0bsq7SwRYTHAKKq/wcmKqFez9sRdbED+SPMUmK5omfZ6w3EwUFQHbYUa4zPBYedQ7Knv70RMw==", + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1" + } + }, + "node_modules/@swc/helpers": { + "version": "0.5.15", + "resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.15.tgz", + "integrity": "sha512-JQ5TuMi45Owi4/BIMAJBoSQoOJu12oOk/gADqlcUL9JEdHB8vyjUSsxqeNXnmXHjYKMi2WcYtezGEEhqUI/E2g==", + "dependencies": { + "tslib": "^2.8.0" + } + }, + "node_modules/@tanstack/react-virtual": { + "version": "3.13.0", + "resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.13.0.tgz", + "integrity": "sha512-CchF0NlLIowiM2GxtsoKBkXA4uqSnY2KvnXo+kyUFD4a4ll6+J0qzoRsUPMwXV/H26lRsxgJIr/YmjYum2oEjg==", + "dependencies": { + "@tanstack/virtual-core": "3.13.0" + }, + "funding": { + "type": "github", + "url": 
"https://github.com/sponsors/tannerlinsley" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, + "node_modules/@tanstack/virtual-core": { + "version": "3.13.0", + "resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.13.0.tgz", + "integrity": "sha512-NBKJP3OIdmZY3COJdWkSonr50FMVIi+aj5ZJ7hI/DTpEKg2RMfo/KvP8A3B/zOSpMgIe52B5E2yn7rryULzA6g==", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + } + }, + "node_modules/@tootallnate/once": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@tootallnate/once/-/once-1.1.2.tgz", + "integrity": "sha512-RbzJvlNzmRq5c3O09UipeuXno4tA1FE6ikOjxZK0tuxVv3412l64l5t1W5pj4+rJq9vpkm/kwiR07aZXnsKPxw==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 6" + } + }, + "node_modules/@tsconfig/node10": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.11.tgz", + "integrity": "sha512-DcRjDCujK/kCk/cUe8Xz8ZSpm8mS3mNNpta+jGCA6USEDfktlNvm1+IuZ9eTcDbNk41BHwpHHeW+N1lKCz4zOw==", + "dev": true + }, + "node_modules/@tsconfig/node12": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz", + "integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==", + "dev": true + }, + "node_modules/@tsconfig/node14": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz", + "integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==", + "dev": true + }, + "node_modules/@tsconfig/node16": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.4.tgz", + "integrity": "sha512-vxhUy4J8lyeyinH7Azl1pdd43GJhZH/tP2weN8TntQblOY+A0XbT8DJk1/oCPuOOyg/Ja757rG0CgHcWC8OfMA==", + "dev": true + }, + 
"node_modules/@types/archiver": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/@types/archiver/-/archiver-6.0.3.tgz", + "integrity": "sha512-a6wUll6k3zX6qs5KlxIggs1P1JcYJaTCx2gnlr+f0S1yd2DoaEwoIK10HmBaLnZwWneBz+JBm0dwcZu0zECBcQ==", + "dev": true, + "dependencies": { + "@types/readdir-glob": "*" + } + }, + "node_modules/@types/node": { + "version": "20.17.19", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.19.tgz", + "integrity": "sha512-LEwC7o1ifqg/6r2gn9Dns0f1rhK+fPFDoMiceTJ6kWmVk6bgXBI/9IOWfVan4WiAavK9pIVWdX0/e3J+eEUh5A==", + "dev": true, + "dependencies": { + "undici-types": "~6.19.2" + } + }, + "node_modules/@types/parse-json": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@types/parse-json/-/parse-json-4.0.2.tgz", + "integrity": "sha512-dISoDXWWQwUquiKsyZ4Ng+HX2KsPL7LyHKHQwgGFEA3IaKac4Obd+h2a/a6waisAoepJlBcx9paWqjA8/HVjCw==" + }, + "node_modules/@types/react": { + "version": "19.0.10", + "resolved": "https://registry.npmjs.org/@types/react/-/react-19.0.10.tgz", + "integrity": "sha512-JuRQ9KXLEjaUNjTWpzuR231Z2WpIwczOkBEIvbHNCzQefFIT0L8IqE6NV6ULLyC1SI/i234JnDoMkfg+RjQj2g==", + "dependencies": { + "csstype": "^3.0.2" + } + }, + "node_modules/@types/react-dom": { + "version": "19.0.4", + "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-19.0.4.tgz", + "integrity": "sha512-4fSQ8vWFkg+TGhePfUzVmat3eC14TXYSsiiDSLI0dVLsrm9gZFABjPy/Qu6TKgl1tq1Bu1yDsuQgY3A3DOjCcg==", + "dev": true, + "peerDependencies": { + "@types/react": "^19.0.0" + } + }, + "node_modules/@types/react-transition-group": { + "version": "4.4.12", + "resolved": "https://registry.npmjs.org/@types/react-transition-group/-/react-transition-group-4.4.12.tgz", + "integrity": "sha512-8TV6R3h2j7a91c+1DXdJi3Syo69zzIZbz7Lg5tORM5LEJG7X/E6a1V3drRyBRZq7/utz7A+c4OgYLiLcYGHG6w==", + "peerDependencies": { + "@types/react": "*" + } + }, + "node_modules/@types/readdir-glob": { + "version": "1.1.5", + "resolved": 
"https://registry.npmjs.org/@types/readdir-glob/-/readdir-glob-1.1.5.tgz", + "integrity": "sha512-raiuEPUYqXu+nvtY2Pe8s8FEmZ3x5yAH4VkLdihcPdalvsHltomrRC9BzuStrJ9yk06470hS0Crw0f1pXqD+Hg==", + "dev": true, + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/strip-bom": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@types/strip-bom/-/strip-bom-3.0.0.tgz", + "integrity": "sha512-xevGOReSYGM7g/kUBZzPqCrR/KYAo+F0yiPc85WFTJa0MSLtyFTVTU6cJu/aV4mid7IffDIWqo69THF2o4JiEQ==", + "dev": true + }, + "node_modules/@types/strip-json-comments": { + "version": "0.0.30", + "resolved": "https://registry.npmjs.org/@types/strip-json-comments/-/strip-json-comments-0.0.30.tgz", + "integrity": "sha512-7NQmHra/JILCd1QqpSzl8+mJRc8ZHz3uDm8YV1Ks9IhK0epEiTw8aIErbvH9PI+6XbqhyIQy3462nEsn7UVzjQ==", + "dev": true + }, + "node_modules/abbrev": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/abbrev/-/abbrev-1.1.1.tgz", + "integrity": "sha512-nne9/IiQ/hzIhY6pdDnbBtz7DjPTKrY00P/zvPSm5pOFkl6xuGrGnXn/VtTNNfNtAfZ9/1RtehkszU9qcTii0Q==", + "license": "ISC", + "optional": true + }, + "node_modules/abort-controller": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/abort-controller/-/abort-controller-3.0.0.tgz", + "integrity": "sha512-h8lQ8tacZYnR3vNQTgibj+tODHI5/+l06Au2Pcriv/Gmet0eaj4TwWH41sO9wnHDiQsEj19q0drzdWdeAHtweg==", + "dependencies": { + "event-target-shim": "^5.0.0" + }, + "engines": { + "node": ">=6.5" + } + }, + "node_modules/acorn": { + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", + "dev": true, + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-walk": { + "version": "8.3.4", + "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.3.4.tgz", + "integrity": 
"sha512-ueEepnujpqee2o5aIYnvHU6C0A42MNdsIDeqy5BydrkuC5R1ZuUFnm27EeFJGoEHJQgn3uleRvmTXaJgfXbt4g==", + "dev": true, + "dependencies": { + "acorn": "^8.11.0" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/agent-base": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz", + "integrity": "sha512-RZNwNclF7+MS/8bDg70amg32dyeZGZxiDuQmZxKLAlQjr3jGyLx+4Kkk58UO7D2QdgFIQCovuSuZESne6RG6XQ==", + "license": "MIT", + "dependencies": { + "debug": "4" + }, + "engines": { + "node": ">= 6.0.0" + } + }, + "node_modules/agentkeepalive": { + "version": "4.6.0", + "resolved": "https://registry.npmjs.org/agentkeepalive/-/agentkeepalive-4.6.0.tgz", + "integrity": "sha512-kja8j7PjmncONqaTsB8fQ+wE2mSU2DJ9D4XKoJ5PFWIdRMa6SLSN1ff4mOr4jCbfRSsxR4keIiySJU0N9T5hIQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "humanize-ms": "^1.2.1" + }, + "engines": { + "node": ">= 8.0.0" + } + }, + "node_modules/aggregate-error": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/aggregate-error/-/aggregate-error-3.1.0.tgz", + "integrity": "sha512-4I7Td01quW/RpocfNayFdFVk1qSuoh0E7JrbRJ16nH01HhKFQ88INq9Sd+nd72zqRySlr9BmDA8xlEJ6vJMrYA==", + "license": "MIT", + "optional": true, + "dependencies": { + "clean-stack": "^2.0.0", + "indent-string": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-escapes": { + "version": "7.3.0", + "resolved": "https://registry.npmjs.org/ansi-escapes/-/ansi-escapes-7.3.0.tgz", + "integrity": "sha512-BvU8nYgGQBxcmMuEeUEmNTvrMVjJNSH7RgW24vXexN4Ven6qCvy4TntnvlnwnMLTVlcRQQdbRY8NKnaIoeWDNg==", + "license": "MIT", + "optional": true, + "dependencies": { + "environment": "^1.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/ansi-regex": { + "version": "6.2.2", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.2.2.tgz", + "integrity": 
"sha512-Bq3SmSpyFHaWjPk8If9yc6svM8c56dB5BAtW4Qbw5jHTwwXXcTLoRMkpDJp6VL0XzlWaCHTXrkFURMYmD0sLqg==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-regex?sponsor=1" + } + }, + "node_modules/ansi-styles": { + "version": "6.2.3", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.3.tgz", + "integrity": "sha512-4Dj6M28JB+oAH8kFkTLUo+a2jwOFkuqb3yucU0CANcRRUbxS0cP0nZYCGjcc3BNXwRIsUVmDGgzawme7zvJHvg==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/any-promise": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/any-promise/-/any-promise-1.3.0.tgz", + "integrity": "sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==", + "dev": true + }, + "node_modules/anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "dev": true, + "dependencies": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/aproba": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/aproba/-/aproba-2.0.0.tgz", + "integrity": "sha512-lYe4Gx7QT+MKGbDsA+Z+he/Wtef0BiwDOlK/XkBrdfsh9J/jPPXbX0tE9x9cl27Tmu5gg3QUbUrQYa/y+KOHPQ==", + "license": "ISC", + "optional": true + }, + "node_modules/archiver": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/archiver/-/archiver-7.0.1.tgz", + "integrity": "sha512-ZcbTaIqJOfCc03QwD468Unz/5Ir8ATtvAHsK+FdXbDIbGfihqh9mrvdcYunQzqn4HrvWWaFyaxJhGZagaJJpPQ==", + "dependencies": { + "archiver-utils": "^5.0.2", + "async": "^3.2.4", + "buffer-crc32": "^1.0.0", + "readable-stream": "^4.0.0", + "readdir-glob": "^1.1.2", + "tar-stream": "^3.0.0", + "zip-stream": "^6.0.1" + }, + 
"engines": { + "node": ">= 14" + } + }, + "node_modules/archiver-utils": { + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/archiver-utils/-/archiver-utils-5.0.2.tgz", + "integrity": "sha512-wuLJMmIBQYCsGZgYLTy5FIB2pF6Lfb6cXMSF8Qywwk3t20zWnAi7zLcQFdKQmIB8wyZpY5ER38x08GbwtR2cLA==", + "dependencies": { + "glob": "^10.0.0", + "graceful-fs": "^4.2.0", + "is-stream": "^2.0.1", + "lazystream": "^1.0.0", + "lodash": "^4.17.15", + "normalize-path": "^3.0.0", + "readable-stream": "^4.0.0" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/archiver-utils/node_modules/buffer": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/archiver-utils/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + "dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + }, + "node_modules/archiver/node_modules/buffer": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": 
"patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/archiver/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + "dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + }, + "node_modules/archiver/node_modules/tar-stream": { + "version": "3.1.7", + "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.7.tgz", + "integrity": "sha512-qJj60CXt7IU1Ffyc3NJMjh6EkuCFej46zUqJ4J7pqYlThyd9bO0XBTmcOIhSzZJVWfsLks0+nle/j538YAW9RQ==", + "dependencies": { + "b4a": "^1.6.4", + "fast-fifo": "^1.2.0", + "streamx": "^2.15.0" + } + }, + "node_modules/are-we-there-yet": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/are-we-there-yet/-/are-we-there-yet-3.0.1.tgz", + "integrity": "sha512-QZW4EDmGwlYur0Yyf/b2uGucHQMa8aFUP7eu9ddR73vvhFyt4V0Vl3QHPcTNJ8l6qYOBdxgXdnBXQrHilfRQBg==", + "deprecated": "This package is no longer supported.", + "license": "ISC", + "optional": true, + "dependencies": { + "delegates": "^1.0.0", + "readable-stream": "^3.6.0" + }, + "engines": { + "node": "^12.13.0 || ^14.15.0 || >=16.0.0" + } + }, + "node_modules/arg": { + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/arg/-/arg-5.0.2.tgz", + "integrity": "sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg==", + "dev": true + }, + "node_modules/async": { + "version": "3.2.6", + "resolved": "https://registry.npmjs.org/async/-/async-3.2.6.tgz", + "integrity": 
"sha512-htCUDlxyyCLMgaM3xXg0C0LW2xqfuQ6p05pCEIsXuyQ+a1koYKTuBMzRNwmybfLgvJDMd0r1LTn4+E0Ti6C2AA==" + }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", + "license": "MIT" + }, + "node_modules/attr-accept": { + "version": "2.2.5", + "resolved": "https://registry.npmjs.org/attr-accept/-/attr-accept-2.2.5.tgz", + "integrity": "sha512-0bDNnY/u6pPwHDMoF0FieU354oBi0a8rD9FcsLwzcGWbc8KS8KPIi7y+s13OlVY+gMWc/9xEMUgNE6Qm8ZllYQ==", + "engines": { + "node": ">=4" + } + }, + "node_modules/auto-bind": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/auto-bind/-/auto-bind-5.0.1.tgz", + "integrity": "sha512-ooviqdwwgfIfNmDwo94wlshcdzfO64XV0Cg6oDsDYBJfITDz1EngD2z7DkbvCWn+XIMsIqW27sEVF6qcpJrRcg==", + "license": "MIT", + "optional": true, + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/axios": { + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.17.0.tgz", + "integrity": "sha512-J8SwNxprqqpbfenehxWYXE7CW+wM1BB4w3+N+g+/Wx40xM4rsLrfPmHHxSWIxJLYDgSY/HqlFPIYb2/S3rxafw==", + "license": "MIT", + "dependencies": { + "follow-redirects": "^1.16.0", + "form-data": "^4.0.5", + "https-proxy-agent": "^5.0.1", + "proxy-from-env": "^2.1.0" + } + }, + "node_modules/b4a": { + "version": "1.6.7", + "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.7.tgz", + "integrity": "sha512-OnAYlL5b7LEkALw87fUVafQw5rVR9RjwGd4KUwNQ6DrrNmaVaUCgLipfVlzrPQ4tWOR9P0IXGNOx50jYCCdSJg==" + }, + "node_modules/babel-plugin-macros": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/babel-plugin-macros/-/babel-plugin-macros-3.1.0.tgz", + "integrity": "sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==", + "dependencies": { + 
"@babel/runtime": "^7.12.5", + "cosmiconfig": "^7.0.0", + "resolve": "^1.19.0" + }, + "engines": { + "node": ">=10", + "npm": ">=6" + } + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==" + }, + "node_modules/bare-events": { + "version": "2.6.1", + "resolved": "https://registry.npmjs.org/bare-events/-/bare-events-2.6.1.tgz", + "integrity": "sha512-AuTJkq9XmE6Vk0FJVNq5QxETrSA/vKHarWVBG5l/JbdCL1prJemiyJqUS0jrlXO0MftuPq4m3YVYhoNc5+aE/g==", + "optional": true + }, + "node_modules/base64-js": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", + "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/binary-extensions": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", + "integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==", + "dev": true, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/bindings": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/bindings/-/bindings-1.5.0.tgz", + "integrity": "sha512-p2q/t/mhvuOj/UeLlV6566GD/guowlr0hHxClI0W9m7MWYkL1F0hLo+0Aexs9HSPCtR1SXQ0TD3MMKrXZajbiQ==", + "license": "MIT", + "dependencies": { + "file-uri-to-path": "1.0.0" + } + }, + "node_modules/bl": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz", + 
"integrity": "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==", + "license": "MIT", + "dependencies": { + "buffer": "^5.5.0", + "inherits": "^2.0.4", + "readable-stream": "^3.4.0" + } + }, + "node_modules/brace-expansion": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.1.1.tgz", + "integrity": "sha512-WR1cURNjuvBLMZBMbqM0UoE+WAfdUcEV1ccD8PVBVOI+Z3ND4+SZbN8RsfT2bMuG1qwz5RFvPukSZm5fF2D5eA==", + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/braces": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", + "dev": true, + "dependencies": { + "fill-range": "^7.1.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/buffer": { + "version": "5.7.1", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-5.7.1.tgz", + "integrity": "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT", + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.1.13" + } + }, + "node_modules/buffer-crc32": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/buffer-crc32/-/buffer-crc32-1.0.0.tgz", + "integrity": "sha512-Db1SbgBS/fg/392AblrMJk97KggmvYhr4pB5ZIMTWtaivCPMWLkmb7m21cJvpvgK+J3nsU2CmmixNBZx4vFj/w==", + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/buffer-from": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", + "integrity": 
"sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", + "dev": true + }, + "node_modules/cacache": { + "version": "15.3.0", + "resolved": "https://registry.npmjs.org/cacache/-/cacache-15.3.0.tgz", + "integrity": "sha512-VVdYzXEn+cnbXpFgWs5hTT7OScegHVmLhJIR8Ufqk3iFD6A6j5iSX1KuBTfNEv4tdJWE2PzA6IVFtcLC7fN9wQ==", + "license": "ISC", + "optional": true, + "dependencies": { + "@npmcli/fs": "^1.0.0", + "@npmcli/move-file": "^1.0.1", + "chownr": "^2.0.0", + "fs-minipass": "^2.0.0", + "glob": "^7.1.4", + "infer-owner": "^1.0.4", + "lru-cache": "^6.0.0", + "minipass": "^3.1.1", + "minipass-collect": "^1.0.2", + "minipass-flush": "^1.0.5", + "minipass-pipeline": "^1.2.2", + "mkdirp": "^1.0.3", + "p-map": "^4.0.0", + "promise-inflight": "^1.0.1", + "rimraf": "^3.0.2", + "ssri": "^8.0.1", + "tar": "^6.0.2", + "unique-filename": "^1.1.1" + }, + "engines": { + "node": ">= 10" + } + }, + "node_modules/cacache/node_modules/brace-expansion": { + "version": "1.1.15", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.15.tgz", + "integrity": "sha512-EwOCDEex4quD37XhqM3omwtMoJjr//isUZz1JopUNWms+4Z2ViyM/k1YIRePpoVNnQhENnxtFjLaxNHrT7xIUg==", + "license": "MIT", + "optional": true, + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/cacache/node_modules/glob": { + "version": "7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "deprecated": "Glob versions prior to v9 are no longer supported", + "license": "ISC", + "optional": true, + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + 
"node_modules/cacache/node_modules/lru-cache": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/cacache/node_modules/minimatch": { + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", + "license": "ISC", + "optional": true, + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/cacache/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/callsites": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "engines": { + "node": ">=6" + } + }, + "node_modules/camelcase-css": { + "version": "2.0.1", + "resolved": 
"https://registry.npmjs.org/camelcase-css/-/camelcase-css-2.0.1.tgz", + "integrity": "sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==", + "dev": true, + "engines": { + "node": ">= 6" + } + }, + "node_modules/caniuse-lite": { + "version": "1.0.30001700", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001700.tgz", + "integrity": "sha512-2S6XIXwaE7K7erT8dY+kLQcpa5ms63XlRkMkReXjle+kf6c5g38vyMl+Z5y8dSxOFDhcFe+nxnn261PLxBSQsQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/caniuse-lite" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ] + }, + "node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/chalk/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/chalk/node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "dependencies": { + "has-flag": 
"^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/chokidar": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", + "integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==", + "dev": true, + "dependencies": { + "anymatch": "~3.1.2", + "braces": "~3.0.2", + "glob-parent": "~5.1.2", + "is-binary-path": "~2.1.0", + "is-glob": "~4.0.1", + "normalize-path": "~3.0.0", + "readdirp": "~3.6.0" + }, + "engines": { + "node": ">= 8.10.0" + }, + "funding": { + "url": "https://paulmillr.com/funding/" + }, + "optionalDependencies": { + "fsevents": "~2.3.2" + } + }, + "node_modules/chokidar/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/chownr": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/chownr/-/chownr-2.0.0.tgz", + "integrity": "sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==", + "license": "ISC", + "engines": { + "node": ">=10" + } + }, + "node_modules/classnames": { + "version": "2.5.1", + "resolved": "https://registry.npmjs.org/classnames/-/classnames-2.5.1.tgz", + "integrity": "sha512-saHYOzhIQs6wy2sVxTM6bUDsQO4F50V9RQ22qBpEdCW+I+/Wmke2HOl6lS6dTpdxVhb88/I6+Hs+438c3lfUow==" + }, + "node_modules/clean-stack": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz", + "integrity": "sha512-4diC9HaTE+KRAMWhDhrGOECgWZxoevMc5TlkObMqNSsVU62PYzXZ/SMTjzyGAFF1YusgxGcSWTEXBhp0CPwQ1A==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/cli-boxes": { + "version": "3.0.0", + "resolved": 
"https://registry.npmjs.org/cli-boxes/-/cli-boxes-3.0.0.tgz", + "integrity": "sha512-/lzGpEWL/8PfI0BmBOPRwp0c/wFNX1RdUML3jK/RcSBA9T8mZDdQpqYBKtCFTOfQbwPqWEOpjqW+Fnayc0969g==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/cli-cursor": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/cli-cursor/-/cli-cursor-4.0.0.tgz", + "integrity": "sha512-VGtlMu3x/4DOtIUwEkRezxUZ2lBacNJCHash0N0WeZDBS+7Ux1dm3XWAgWYxLJFMMdOeXMHXorshEFhbMSGelg==", + "license": "MIT", + "optional": true, + "dependencies": { + "restore-cursor": "^4.0.0" + }, + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/cli-truncate": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/cli-truncate/-/cli-truncate-5.2.0.tgz", + "integrity": "sha512-xRwvIOMGrfOAnM1JYtqQImuaNtDEv9v6oIYAs4LIHwTiKee8uwvIi363igssOC0O5U04i4AlENs79LQLu9tEMw==", + "license": "MIT", + "optional": true, + "dependencies": { + "slice-ansi": "^8.0.0", + "string-width": "^8.2.0" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/cli-truncate/node_modules/string-width": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-8.2.0.tgz", + "integrity": "sha512-6hJPQ8N0V0P3SNmP6h2J99RLuzrWz2gvT7VnK5tKvrNqJoyS9W4/Fb8mo31UiPvy00z7DQXkP2hnKBVav76thw==", + "license": "MIT", + "optional": true, + "dependencies": { + "get-east-asian-width": "^1.5.0", + "strip-ansi": "^7.1.2" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/client-only": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz", + "integrity": 
"sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==" + }, + "node_modules/cliui": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz", + "integrity": "sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==", + "dev": true, + "dependencies": { + "string-width": "^4.2.0", + "strip-ansi": "^6.0.1", + "wrap-ansi": "^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/cliui/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/cliui/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/cliui/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true + }, + "node_modules/cliui/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + 
"node_modules/cliui/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/cliui/node_modules/wrap-ansi": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "dev": true, + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/clone": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/clone/-/clone-2.1.2.tgz", + "integrity": "sha512-3Pe/CF1Nn94hyhIYpjtiLhdCoEoz0DqQ+988E9gmeEdQZlojxnOb74wctFyuwWQHzqyf9X7C7MG8juUpqBJT8w==", + "engines": { + "node": ">=0.8" + } + }, + "node_modules/clsx": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz", + "integrity": "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==", + "engines": { + "node": ">=6" + } + }, + "node_modules/code-excerpt": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/code-excerpt/-/code-excerpt-4.0.0.tgz", + "integrity": "sha512-xxodCmBen3iy2i0WtAK8FlFNrRzjUqjRsMfho58xT/wvZU1YTM3fCnRjcy1gJPMepaRlgm/0e6w8SpWHpn3/cA==", + "license": "MIT", + "optional": true, + "dependencies": { + "convert-to-spaces": "^2.0.1" + }, + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": 
"sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==" + }, + "node_modules/color-support": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/color-support/-/color-support-1.1.3.tgz", + "integrity": "sha512-qiBjkpbMLO/HL68y+lh4q0/O1MZFj2RX6X/KmMa3+gJD3z+WwI1ZzDHysvqHGS3mP6mznPckpXmw1nI9cJjyRg==", + "license": "ISC", + "optional": true, + "bin": { + "color-support": "bin.js" + } + }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "license": "MIT", + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/commander": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/commander/-/commander-4.1.1.tgz", + "integrity": "sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==", + "dev": true, + "engines": { + "node": ">= 6" + } + }, + "node_modules/compress-commons": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/compress-commons/-/compress-commons-6.0.2.tgz", + "integrity": "sha512-6FqVXeETqWPoGcfzrXb37E50NP0LXT8kAMu5ooZayhWWdgEY4lBEEcbQNXtkuKQsGduxiIcI4gOTsxTmuq/bSg==", + "dependencies": { + "crc-32": "^1.2.0", + "crc32-stream": "^6.0.0", + "is-stream": "^2.0.1", + "normalize-path": "^3.0.0", + "readable-stream": "^4.0.0" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/compress-commons/node_modules/buffer": { + "version": 
"6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/compress-commons/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + "dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + }, + "node_modules/concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/concurrently": { + "version": "9.1.2", + "resolved": "https://registry.npmjs.org/concurrently/-/concurrently-9.1.2.tgz", + "integrity": "sha512-H9MWcoPsYddwbOGM6difjVwVZHl63nwMEwDJG/L7VGtuaJhb12h2caPG2tVPWs7emuYix252iGfqOyrz1GczTQ==", + "dev": true, + "dependencies": { + "chalk": "^4.1.2", + "lodash": "^4.17.21", + "rxjs": "^7.8.1", + "shell-quote": "^1.8.1", + "supports-color": "^8.1.1", + "tree-kill": "^1.2.2", + "yargs": "^17.7.2" + }, + "bin": { + "conc": "dist/bin/concurrently.js", + "concurrently": "dist/bin/concurrently.js" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/open-cli-tools/concurrently?sponsor=1" + } + }, 
+ "node_modules/console-control-strings": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/console-control-strings/-/console-control-strings-1.1.0.tgz", + "integrity": "sha512-ty/fTekppD2fIwRvnZAVdeOiGd1c7YXEixbgJTNzqcxJWKQnjJ/V1bNEEE6hygpM3WjwHFUVK6HTjWSzV4a8sQ==", + "license": "ISC", + "optional": true + }, + "node_modules/convert-source-map": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-1.9.0.tgz", + "integrity": "sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A==" + }, + "node_modules/convert-to-spaces": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/convert-to-spaces/-/convert-to-spaces-2.0.1.tgz", + "integrity": "sha512-rcQ1bsQO9799wq24uE5AM2tAILy4gXGIK/njFWcVQkGNZ96edlpY+A7bjwvzjYvLDyzmG1MmMLZhpcsb+klNMQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + } + }, + "node_modules/core-util-is": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", + "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==" + }, + "node_modules/cosmiconfig": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/cosmiconfig/-/cosmiconfig-7.1.0.tgz", + "integrity": "sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA==", + "dependencies": { + "@types/parse-json": "^4.0.0", + "import-fresh": "^3.2.1", + "parse-json": "^5.0.0", + "path-type": "^4.0.0", + "yaml": "^1.10.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/cosmiconfig/node_modules/yaml": { + "version": "1.10.3", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.3.tgz", + "integrity": "sha512-vIYeF1u3CjlhAFekPPAk2h/Kv4T3mAkMox5OymRiJQB0spDP10LHvt+K7G9Ny6NuuMAb25/6n1qyUjAcGNf/AA==", + "license": "ISC", + "engines": { + "node": ">= 6" + } + }, + 
"node_modules/crc-32": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/crc-32/-/crc-32-1.2.2.tgz", + "integrity": "sha512-ROmzCKrTnOwybPcJApAA6WBWij23HVfGVNKqqrZpuyZOHqK2CwHSvpGuyt/UNNvaIjEd8X5IFGp4Mh+Ie1IHJQ==", + "bin": { + "crc32": "bin/crc32.njs" + }, + "engines": { + "node": ">=0.8" + } + }, + "node_modules/crc32-stream": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/crc32-stream/-/crc32-stream-6.0.0.tgz", + "integrity": "sha512-piICUB6ei4IlTv1+653yq5+KoqfBYmj9bw6LqXoOneTMDXk5nM1qt12mFW1caG3LlJXEKW1Bp0WggEmIfQB34g==", + "dependencies": { + "crc-32": "^1.2.0", + "readable-stream": "^4.0.0" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/crc32-stream/node_modules/buffer": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/crc32-stream/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + "dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + }, + "node_modules/create-require": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz", + "integrity": 
"sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==", + "dev": true + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/cssesc": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", + "integrity": "sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg==", + "dev": true, + "bin": { + "cssesc": "bin/cssesc" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/csstype": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", + "integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==" + }, + "node_modules/debug": { + "version": "4.4.0", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.0.tgz", + "integrity": "sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA==", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/decompress-response": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-6.0.0.tgz", + "integrity": "sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==", + "license": "MIT", + "dependencies": { + "mimic-response": "^3.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/deep-extend": { + "version": 
"0.6.0", + "resolved": "https://registry.npmjs.org/deep-extend/-/deep-extend-0.6.0.tgz", + "integrity": "sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==", + "license": "MIT", + "engines": { + "node": ">=4.0.0" + } + }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "license": "MIT", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/delegates": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/delegates/-/delegates-1.0.0.tgz", + "integrity": "sha512-bd2L678uiWATM6m5Z1VzNCErI3jiGzt6HGY8OVICs40JQq/HALfbyNJmp0UDakEY4pMMaN0Ly5om/B1VI/+xfQ==", + "license": "MIT", + "optional": true + }, + "node_modules/detect-libc": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.1.2.tgz", + "integrity": "sha512-Btj2BOOO83o3WyH59e8MgXsxEQVcarkUOpEYrubB0urwnN10yQ364rsiByU11nZlqWYZm05i/of7io4mzihBtQ==", + "license": "Apache-2.0", + "engines": { + "node": ">=8" + } + }, + "node_modules/didyoumean": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", + "integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==", + "dev": true + }, + "node_modules/diff": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.4.tgz", + "integrity": "sha512-X07nttJQkwkfKfvTPG/KSnE2OMdcUCao6+eXF3wmnIQRn2aPAHH3VxDbDOdegkd6JbPsXqShpvEOHfAT+nCNwQ==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.3.1" + } + }, + "node_modules/dlv": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/dlv/-/dlv-1.1.3.tgz", + "integrity": "sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA==", + 
"dev": true + }, + "node_modules/dom-helpers": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/dom-helpers/-/dom-helpers-5.2.1.tgz", + "integrity": "sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==", + "dependencies": { + "@babel/runtime": "^7.8.7", + "csstype": "^3.0.2" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/dynamic-dedupe": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/dynamic-dedupe/-/dynamic-dedupe-0.3.0.tgz", + "integrity": "sha512-ssuANeD+z97meYOqd50e04Ze5qp4bPqo8cCkI4TRjZkzAUgIDTrXV1R8QCdINpiI+hw14+rYazvTRdQrz0/rFQ==", + "dev": true, + "dependencies": { + "xtend": "^4.0.0" + } + }, + "node_modules/eastasianwidth": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz", + "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==" + }, + "node_modules/emoji-regex": { + "version": "9.2.2", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", + "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==" + }, + "node_modules/encoding": { + "version": "0.1.13", + "resolved": "https://registry.npmjs.org/encoding/-/encoding-0.1.13.tgz", + "integrity": "sha512-ETBauow1T35Y/WZMkio9jiM0Z5xjHHmJ4XmjZOq1l/dXz3lr2sRn87nJy20RupqSh1F2m3HHPSp8ShIPQJrJ3A==", + "license": "MIT", + "optional": true, + "dependencies": { + "iconv-lite": "^0.6.2" + } + }, + "node_modules/end-of-stream": { + "version": "1.4.4", + 
"resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz", + "integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==", + "license": "MIT", + "dependencies": { + "once": "^1.4.0" + } + }, + "node_modules/env-paths": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/env-paths/-/env-paths-2.2.1.tgz", + "integrity": "sha512-+h1lkLKhZMTYjog1VEpJNG7NZJWcuc2DDk/qsqSTRRCOXiLjeQ1d1/udrUGhqMxUgAlwKNZ0cf2uqan5GLuS2A==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/environment": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/environment/-/environment-1.1.0.tgz", + "integrity": "sha512-xUtoPkMggbz0MPyPiIWr1Kp4aeWJjDZ6SMvURhimjdZgsRuDplF5/s9hcgGhyXMhs+6vpnuoiZ2kFiu3FMnS8Q==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/err-code": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/err-code/-/err-code-2.0.3.tgz", + "integrity": "sha512-2bmlRpNKBxT/CRmPOlyISQpNj+qSeYvcym/uT0Jx2bMOlKLtSy1ZmLuVxSEKKyor/N5yhvp/ZiG1oE3DEYMSFA==", + "license": "MIT", + "optional": true + }, + "node_modules/error-ex": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/error-ex/-/error-ex-1.3.2.tgz", + "integrity": "sha512-7dFHNmqeFSEt2ZBsCriorKnn3Z2pj+fd9kmI6QoWw4//DL+icEBfc0U7qJCisqrTsKTjw4fNFy2pW9OqStD84g==", + "dependencies": { + "is-arrayish": "^0.2.1" + } + }, + "node_modules/error-ex/node_modules/is-arrayish": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.2.1.tgz", + "integrity": "sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg==" + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + 
"integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-object-atoms": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.2.tgz", + "integrity": "sha512-HWcBoN6NileqtSydK2FqHbS/LoDd2pqrnQHLyJzBj4kOp/ky2MWMN694xOfkK8/SnUsW2DH7EfyVlydKCsm1Zw==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-toolkit": { + "version": "1.43.0", + "resolved": "https://registry.npmjs.org/es-toolkit/-/es-toolkit-1.43.0.tgz", + "integrity": "sha512-SKCT8AsWvYzBBuUqMk4NPwFlSdqLpJwmy6AP322ERn8W2YLIB6JBXnwMI2Qsh2gfphT3q7EKAxKb23cvFHFwKA==", + "license": "MIT", + "optional": true, + "workspaces": [ + "docs", + "benchmarks" + ] + }, + "node_modules/escalade": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/escape-string-regexp": { + 
"version": "4.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", + "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/event-target-shim": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/event-target-shim/-/event-target-shim-5.0.1.tgz", + "integrity": "sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==", + "engines": { + "node": ">=6" + } + }, + "node_modules/events": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/events/-/events-3.3.0.tgz", + "integrity": "sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==", + "engines": { + "node": ">=0.8.x" + } + }, + "node_modules/expand-template": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/expand-template/-/expand-template-2.0.3.tgz", + "integrity": "sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==", + "license": "(MIT OR WTFPL)", + "engines": { + "node": ">=6" + } + }, + "node_modules/fast-fifo": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz", + "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ==" + }, + "node_modules/fast-glob": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", + "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", + "dev": true, + "dependencies": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.8" + }, + "engines": { + "node": ">=8.6.0" + } + }, + 
"node_modules/fast-glob/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/fastq": { + "version": "1.19.0", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.0.tgz", + "integrity": "sha512-7SFSRCNjBQIZH/xZR3iy5iQYR8aGBE0h3VG6/cwlbrpdciNYBMotQav8c1XI3HjHH+NikUpP53nPdlZSdWmFzA==", + "dev": true, + "dependencies": { + "reusify": "^1.0.4" + } + }, + "node_modules/file-selector": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/file-selector/-/file-selector-2.1.2.tgz", + "integrity": "sha512-QgXo+mXTe8ljeqUFaX3QVHc5osSItJ/Km+xpocx0aSqWGMSCf6qYs/VnzZgS864Pjn5iceMRFigeAV7AfTlaig==", + "dependencies": { + "tslib": "^2.7.0" + }, + "engines": { + "node": ">= 12" + } + }, + "node_modules/file-uri-to-path": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/file-uri-to-path/-/file-uri-to-path-1.0.0.tgz", + "integrity": "sha512-0Zt+s3L7Vf1biwWZ29aARiVYLx7iMGnEUl9x33fbB/j3jR81u/O2LbqK+Bm1CDSNDKVtJ/YjwY7TUd5SkeLQLw==", + "license": "MIT" + }, + "node_modules/fill-range": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", + "dev": true, + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/find-root": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/find-root/-/find-root-1.1.0.tgz", + "integrity": "sha512-NKfW6bec6GfKc0SGx1e07QZY9PE99u0Bft/0rzSD5k3sO/vwkVUpDUKVm5Gpp5Ue3YfShPFTX2070tDs5kB9Ng==" + }, + "node_modules/follow-redirects": { + "version": "1.16.0", + "resolved": 
"https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.16.0.tgz", + "integrity": "sha512-y5rN/uOsadFT/JfYwhxRS5R7Qce+g3zG97+JrtFZlC9klX/W5hD7iiLzScI4nZqUS7DNUdhPgw4xI8W2LuXlUw==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "license": "MIT", + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, + "node_modules/foreground-child": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.3.0.tgz", + "integrity": "sha512-Ld2g8rrAyMYFXBhEqMz8ZAHBi4J4uS1i/CxGMDnjyFWddMXLVcDp051DZfu+t7+ab7Wv6SMqpWmyFIj5UbfFvg==", + "dependencies": { + "cross-spawn": "^7.0.0", + "signal-exit": "^4.0.1" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/form-data": { + "version": "4.0.6", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.6.tgz", + "integrity": "sha512-vKatAh4SlVfgbv+YtmhiRjhEMJsYpsG1Y2rMQtR+SVSbytsSD1YGzDIcrAJmdFec88u/+VoGmxnl+80gL1tRCQ==", + "license": "MIT", + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.4", + "mime-types": "^2.1.35" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/fs-constants": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz", + "integrity": "sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==", + "license": "MIT" + }, + "node_modules/fs-minipass": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fs-minipass/-/fs-minipass-2.1.0.tgz", + "integrity": "sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==", + "license": "ISC", + "dependencies": { + "minipass": "^3.0.0" + }, + "engines": { + "node": ">= 8" + } + }, + 
"node_modules/fs-minipass/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", + "devOptional": true, + "license": "ISC" + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "hasInstallScript": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/gauge": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/gauge/-/gauge-4.0.4.tgz", + "integrity": "sha512-f9m+BEN5jkg6a0fZjleidjN51VE1X+mPFQ2DJ0uv1V39oCLCbsGe6yjbBnp7eK7z/+GAon99a3nHuqbuuthyPg==", + "deprecated": "This package is no longer supported.", + "license": "ISC", + "optional": true, + "dependencies": { + "aproba": "^1.0.3 || ^2.0.0", + "color-support": "^1.1.3", + "console-control-strings": "^1.1.0", + "has-unicode": "^2.0.1", + "signal-exit": "^3.0.7", + "string-width": "^4.2.3", + "strip-ansi": "^6.0.1", + "wide-align": "^1.1.5" + }, + "engines": { + "node": "^12.13.0 || ^14.15.0 
|| >=16.0.0" + } + }, + "node_modules/gauge/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/gauge/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT", + "optional": true + }, + "node_modules/gauge/node_modules/signal-exit": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", + "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", + "license": "ISC", + "optional": true + }, + "node_modules/gauge/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "optional": true, + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/gauge/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "optional": true, + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/get-caller-file": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz", + 
"integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", + "dev": true, + "engines": { + "node": "6.* || 8.* || >= 10.*" + } + }, + "node_modules/get-east-asian-width": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/get-east-asian-width/-/get-east-asian-width-1.5.0.tgz", + "integrity": "sha512-CQ+bEO+Tva/qlmw24dCejulK5pMzVnUOFOijVogd3KQs07HnRIgp8TGipvCCRT06xeYEbpbgwaCxglFyiuIcmA==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/github-from-package": { + "version": "0.0.0", + "resolved": "https://registry.npmjs.org/github-from-package/-/github-from-package-0.0.0.tgz", + "integrity": "sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==", + "license": "MIT" + }, + 
"node_modules/glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "dev": true, + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/globals": { + "version": "11.12.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-11.12.0.tgz", + "integrity": "sha512-WOBp/EEGUiIsJSp7wcv/y6MO+lV9UoncWqxuFfm8eBwzWNgyfBd6Gz+IeKQ9jCmyhoH99g15M3T+QaVHFjizVA==", + "engines": { + "node": ">=4" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "license": "ISC" + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": 
"sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", + "license": "MIT", + "dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-unicode": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/has-unicode/-/has-unicode-2.0.1.tgz", + "integrity": "sha512-8Rf9Y83NBReMnx0gFzA8JImQACstCYWUplepDa9xprwwtmgEZUF0h/i5xSA625zB/I37EtrswSST6OXxwaaIJQ==", + "license": "ISC", + "optional": true + }, + "node_modules/hasown": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.4.tgz", + "integrity": "sha512-T2UbfbBEF32wiepXIsMlTW9+dDYC6wMh/t/vYA4tuOMKqWz/n3vr1NFSxQiyP+zk2mXsoMA/i/7qV6LKut1t1A==", + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/hoist-non-react-statics": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/hoist-non-react-statics/-/hoist-non-react-statics-3.3.2.tgz", + "integrity": "sha512-/gGivxi8JPKWNm/W0jSmzcMPpfpPLc3dY/6GxhX2hQ9iGj3aDfklV4ET7NjKpSinLpJ5vafa9iiGIEZg10SfBw==", + "dependencies": { + "react-is": "^16.7.0" + } + }, + "node_modules/http-cache-semantics": { + 
"version": "4.1.1", + "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz", + "integrity": "sha512-er295DKPVsV82j5kw1Gjt+ADA/XYHsajl82cGNQG2eyoPkvgUhX+nDIyelzhIWbbsXP39EHcI6l5tYs2FYqYXQ==", + "license": "BSD-2-Clause", + "optional": true + }, + "node_modules/http-proxy-agent": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-4.0.1.tgz", + "integrity": "sha512-k0zdNgqWTGA6aeIRVpvfVob4fL52dTfaehylg0Y4UvSySvOq/Y+BOyPrgpUrA7HylqvU8vIZGsRuXmspskV0Tg==", + "license": "MIT", + "optional": true, + "dependencies": { + "@tootallnate/once": "1", + "agent-base": "6", + "debug": "4" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/https-proxy-agent": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-5.0.1.tgz", + "integrity": "sha512-dFcAjpTQFgoLMzC2VwU+C/CbS7uRL0lWmxDITmqm7C+7F0Odmj6s9l6alZc6AELXhrnggM2CeWSXHGOdX2YtwA==", + "license": "MIT", + "dependencies": { + "agent-base": "6", + "debug": "4" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/humanize-ms": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/humanize-ms/-/humanize-ms-1.2.1.tgz", + "integrity": "sha512-Fl70vYtsAFb/C06PTS9dZBo7ihau+Tu/DNCk/OyHhea07S+aeMWpFFkUaXRa8fI+ScZbEI8dfSxwY7gxZ9SAVQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "ms": "^2.0.0" + } + }, + "node_modules/iconv-lite": { + "version": "0.6.3", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.6.3.tgz", + "integrity": "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==", + "license": "MIT", + "optional": true, + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/ieee754": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/ieee754/-/ieee754-1.2.1.tgz", + "integrity": 
"sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "BSD-3-Clause" + }, + "node_modules/import-fresh": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", + "integrity": "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==", + "dependencies": { + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=0.8.19" + } + }, + "node_modules/indent-string": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-4.0.0.tgz", + "integrity": "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/infer-owner": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/infer-owner/-/infer-owner-1.0.4.tgz", + "integrity": "sha512-IClj+Xz94+d7irH5qRyfJonOdfTzuDaifE6ZPWfx0N0+/ATZCbuTPq2prFl526urkQd90WyUKIh1DfBQ2hMz9A==", + "license": "ISC", + "optional": true + }, + "node_modules/inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "integrity": 
"sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", + "deprecated": "This module is not supported, and leaks memory. Do not use it. Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful.", + "devOptional": true, + "license": "ISC", + "dependencies": { + "once": "^1.3.0", + "wrappy": "1" + } + }, + "node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "license": "ISC" + }, + "node_modules/ini": { + "version": "1.3.8", + "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz", + "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew==", + "license": "ISC" + }, + "node_modules/ink": { + "version": "6.8.0", + "resolved": "https://registry.npmjs.org/ink/-/ink-6.8.0.tgz", + "integrity": "sha512-sbl1RdLOgkO9isK42WCZlJCFN9hb++sX9dsklOvfd1YQ3bQ2AiFu12Q6tFlr0HvEUvzraJntQCCpfEoUe9DSzA==", + "license": "MIT", + "optional": true, + "dependencies": { + "@alcalzone/ansi-tokenize": "^0.2.4", + "ansi-escapes": "^7.3.0", + "ansi-styles": "^6.2.1", + "auto-bind": "^5.0.1", + "chalk": "^5.6.0", + "cli-boxes": "^3.0.0", + "cli-cursor": "^4.0.0", + "cli-truncate": "^5.1.1", + "code-excerpt": "^4.0.0", + "es-toolkit": "^1.39.10", + "indent-string": "^5.0.0", + "is-in-ci": "^2.0.0", + "patch-console": "^2.0.0", + "react-reconciler": "^0.33.0", + "scheduler": "^0.27.0", + "signal-exit": "^3.0.7", + "slice-ansi": "^8.0.0", + "stack-utils": "^2.0.6", + "string-width": "^8.1.1", + "terminal-size": "^4.0.1", + "type-fest": "^5.4.1", + "widest-line": "^6.0.0", + "wrap-ansi": "^9.0.0", + "ws": "^8.18.0", + "yoga-layout": "~3.2.1" + }, + "engines": { + "node": ">=20" + }, + "peerDependencies": { + "@types/react": ">=19.0.0", 
+ "react": ">=19.0.0", + "react-devtools-core": ">=6.1.2" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "react-devtools-core": { + "optional": true + } + } + }, + "node_modules/ink/node_modules/chalk": { + "version": "5.6.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-5.6.2.tgz", + "integrity": "sha512-7NzBL0rN6fMUW+f7A6Io4h40qQlG+xGmtMxfbnH/K7TAtt8JQWVQK+6g0UXKMeVJoyV5EkkNsErQ8pVD3bLHbA==", + "license": "MIT", + "optional": true, + "engines": { + "node": "^12.17.0 || ^14.13 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/ink/node_modules/emoji-regex": { + "version": "10.6.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-10.6.0.tgz", + "integrity": "sha512-toUI84YS5YmxW219erniWD0CIVOo46xGKColeNQRgOzDorgBi1v4D71/OFzgD9GO2UGKIv1C3Sp8DAn0+j5w7A==", + "license": "MIT", + "optional": true + }, + "node_modules/ink/node_modules/indent-string": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-5.0.0.tgz", + "integrity": "sha512-m6FAo/spmsW2Ab2fU35JTYwtOKa2yAwXSwgjSv1TJzh4Mh7mC3lzAOVLBprb72XsTrgkEIsl7YrFNAiDiRhIGg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/ink/node_modules/react-reconciler": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/react-reconciler/-/react-reconciler-0.33.0.tgz", + "integrity": "sha512-KetWRytFv1epdpJc3J4G75I4WrplZE5jOL7Yq0p34+OVOKF4Se7WrdIdVC45XsSSmUTlht2FM/fM1FZb1mfQeA==", + "license": "MIT", + "optional": true, + "dependencies": { + "scheduler": "^0.27.0" + }, + "engines": { + "node": ">=0.10.0" + }, + "peerDependencies": { + "react": "^19.2.0" + } + }, + "node_modules/ink/node_modules/signal-exit": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", + "integrity": 
"sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", + "license": "ISC", + "optional": true + }, + "node_modules/ink/node_modules/string-width": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-8.2.0.tgz", + "integrity": "sha512-6hJPQ8N0V0P3SNmP6h2J99RLuzrWz2gvT7VnK5tKvrNqJoyS9W4/Fb8mo31UiPvy00z7DQXkP2hnKBVav76thw==", + "license": "MIT", + "optional": true, + "dependencies": { + "get-east-asian-width": "^1.5.0", + "strip-ansi": "^7.1.2" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/ink/node_modules/wrap-ansi": { + "version": "9.0.2", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-9.0.2.tgz", + "integrity": "sha512-42AtmgqjV+X1VpdOfyTGOYRi0/zsoLqtXQckTmqTeybT+BDIbM/Guxo7x3pE2vtpr1ok6xRqM9OpBe+Jyoqyww==", + "license": "MIT", + "optional": true, + "dependencies": { + "ansi-styles": "^6.2.1", + "string-width": "^7.0.0", + "strip-ansi": "^7.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/ink/node_modules/wrap-ansi/node_modules/string-width": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-7.2.0.tgz", + "integrity": "sha512-tsaTIkKW9b4N+AEj+SVA+WhJzV7/zMhcSu78mLKWSk7cXMOSHsBKFWUs0fWwq8QyK3MgJBQRX6Gbi4kYbdvGkQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "emoji-regex": "^10.3.0", + "get-east-asian-width": "^1.0.0", + "strip-ansi": "^7.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/ip-address": { + "version": "10.2.0", + "resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.2.0.tgz", + "integrity": "sha512-/+S6j4E9AHvW9SWMSEY9Xfy66O5PWvVEJ08O0y5JGyEKQpojb0K0GKpz/v5HJ/G0vi3D2sjGK78119oXZeE0qA==", + "license": "MIT", + 
"optional": true, + "engines": { + "node": ">= 12" + } + }, + "node_modules/is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "dev": true, + "dependencies": { + "binary-extensions": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/is-core-module": { + "version": "2.16.1", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.16.1.tgz", + "integrity": "sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "engines": { + "node": ">=8" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-in-ci": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/is-in-ci/-/is-in-ci-2.0.0.tgz", + "integrity": 
"sha512-cFeerHriAnhrQSbpAxL37W1wcJKUUX07HyLWZCW1URJT/ra3GyUTzBgUnh24TMVfNTV2Hij2HLxkPHFZfOZy5w==", + "license": "MIT", + "optional": true, + "bin": { + "is-in-ci": "cli.js" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-lambda": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/is-lambda/-/is-lambda-1.0.1.tgz", + "integrity": "sha512-z7CMFGNrENq5iFB9Bqo64Xk6Y9sg+epq1myIcdHaGnbMTYOxvzsEtdYqQUylB7LxfkvgrrjP32T6Ywciio9UIQ==", + "license": "MIT", + "optional": true + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "dev": true, + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/is-stream": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", + "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/isarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", + "integrity": "sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==" + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==" + }, + "node_modules/jackspeak": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", + "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", + "dependencies": { + "@isaacs/cliui": "^8.0.2" + 
}, + "funding": { + "url": "https://github.com/sponsors/isaacs" + }, + "optionalDependencies": { + "@pkgjs/parseargs": "^0.11.0" + } + }, + "node_modules/jiti": { + "version": "1.21.7", + "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz", + "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", + "dev": true, + "bin": { + "jiti": "bin/jiti.js" + } + }, + "node_modules/js-tokens": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", + "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==" + }, + "node_modules/jsesc": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-3.1.0.tgz", + "integrity": "sha512-/sM3dO2FOzXjKQhJuo0Q173wf2KOo8t4I8vHy6lF9poUp7bKT0/NHE8fPX23PwfhnykfqnC2xRxOnVw5XuGIaA==", + "bin": { + "jsesc": "bin/jsesc" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/json-parse-even-better-errors": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", + "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==" + }, + "node_modules/lazystream": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/lazystream/-/lazystream-1.0.1.tgz", + "integrity": "sha512-b94GiNHQNy6JNTrt5w6zNyffMrNkXZb3KTkCZJb2V1xaEGCk093vkZ2jk3tpaeP33/OiXC+WvK9AxUebnf5nbw==", + "dependencies": { + "readable-stream": "^2.0.5" + }, + "engines": { + "node": ">= 0.6.3" + } + }, + "node_modules/lazystream/node_modules/readable-stream": { + "version": "2.3.8", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.8.tgz", + "integrity": "sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==", + "dependencies": { + "core-util-is": "~1.0.0", + "inherits": "~2.0.3", + 
"isarray": "~1.0.0", + "process-nextick-args": "~2.0.0", + "safe-buffer": "~5.1.1", + "string_decoder": "~1.1.1", + "util-deprecate": "~1.0.1" + } + }, + "node_modules/lazystream/node_modules/safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==" + }, + "node_modules/lazystream/node_modules/string_decoder": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", + "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", + "dependencies": { + "safe-buffer": "~5.1.0" + } + }, + "node_modules/lilconfig": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.3.tgz", + "integrity": "sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==", + "dev": true, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/antonk52" + } + }, + "node_modules/lines-and-columns": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz", + "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==" + }, + "node_modules/lodash": { + "version": "4.18.1", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.18.1.tgz", + "integrity": "sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==", + "license": "MIT" + }, + "node_modules/loose-envify": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", + "integrity": "sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==", + "dependencies": { + "js-tokens": "^3.0.0 || ^4.0.0" + }, + "bin": { + "loose-envify": 
"cli.js" + } + }, + "node_modules/lru-cache": { + "version": "10.4.3", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", + "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==" + }, + "node_modules/lucide-react": { + "version": "0.475.0", + "resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.475.0.tgz", + "integrity": "sha512-NJzvVu1HwFVeZ+Gwq2q00KygM1aBhy/ZrhY9FsAgJtpB+E4R7uxRk9M2iKvHa6/vNxZydIB59htha4c2vvwvVg==", + "license": "ISC", + "peerDependencies": { + "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, + "node_modules/macstats": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/macstats/-/macstats-4.2.0.tgz", + "integrity": "sha512-+NJmIPjndK62WOwmu4Qbbgj5K84rTON7j3kBMOyFERsEztA7yM1ITyBRQQ/Y7cvRyYVeMi11C6KH7CFhIZlc2A==", + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "dependencies": { + "chalk": "^5.6.2", + "ink": "^6.4.0", + "nan": "^2.23.1", + "react": "^19.2.0" + }, + "bin": { + "macstats": "bin/macstats" + } + }, + "node_modules/macstats/node_modules/chalk": { + "version": "5.6.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-5.6.2.tgz", + "integrity": "sha512-7NzBL0rN6fMUW+f7A6Io4h40qQlG+xGmtMxfbnH/K7TAtt8JQWVQK+6g0UXKMeVJoyV5EkkNsErQ8pVD3bLHbA==", + "license": "MIT", + "optional": true, + "engines": { + "node": "^12.17.0 || ^14.13 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/make-error": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz", + "integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==", + "dev": true + }, + "node_modules/make-fetch-happen": { + "version": "9.1.0", + "resolved": "https://registry.npmjs.org/make-fetch-happen/-/make-fetch-happen-9.1.0.tgz", + "integrity": 
"sha512-+zopwDy7DNknmwPQplem5lAZX/eCOzSvSNNcSKm5eVwTkOBzoktEfXsa9L23J/GIRhxRsaxzkPEhrJEpE2F4Gg==", + "license": "ISC", + "optional": true, + "dependencies": { + "agentkeepalive": "^4.1.3", + "cacache": "^15.2.0", + "http-cache-semantics": "^4.1.0", + "http-proxy-agent": "^4.0.1", + "https-proxy-agent": "^5.0.0", + "is-lambda": "^1.0.1", + "lru-cache": "^6.0.0", + "minipass": "^3.1.3", + "minipass-collect": "^1.0.2", + "minipass-fetch": "^1.3.2", + "minipass-flush": "^1.0.5", + "minipass-pipeline": "^1.2.4", + "negotiator": "^0.6.2", + "promise-retry": "^2.0.1", + "socks-proxy-agent": "^6.0.0", + "ssri": "^8.0.0" + }, + "engines": { + "node": ">= 10" + } + }, + "node_modules/make-fetch-happen/node_modules/lru-cache": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/make-fetch-happen/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/memoize-one": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/memoize-one/-/memoize-one-6.0.0.tgz", + "integrity": 
"sha512-rkpe71W0N0c0Xz6QD0eJETuWAJGnJ9afsl1srmwPrI+yBCkge5EycXXbYRyvL29zZVUWQCY7InPRCv3GDXuZNw==" + }, + "node_modules/merge2": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "dev": true, + "engines": { + "node": ">= 8" + } + }, + "node_modules/micromatch": { + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", + "dev": true, + "dependencies": { + "braces": "^3.0.3", + "picomatch": "^2.3.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "license": "MIT", + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mimic-fn": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-2.1.0.tgz", + "integrity": "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/mimic-response": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-3.1.0.tgz", + "integrity": "sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==", 
+ "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/minimatch": { + "version": "9.0.9", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.9.tgz", + "integrity": "sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==", + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.2" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/minimist": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", + "integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/minipass": { + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz", + "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==", + "engines": { + "node": ">=16 || 14 >=14.17" + } + }, + "node_modules/minipass-collect": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/minipass-collect/-/minipass-collect-1.0.2.tgz", + "integrity": "sha512-6T6lH0H8OG9kITm/Jm6tdooIbogG9e0tLgpY6mphXSm/A9u8Nq1ryBG+Qspiub9LjWlBPsPS3tWQ/Botq4FdxA==", + "license": "ISC", + "optional": true, + "dependencies": { + "minipass": "^3.0.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/minipass-collect/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + 
"node_modules/minipass-fetch": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/minipass-fetch/-/minipass-fetch-1.4.1.tgz", + "integrity": "sha512-CGH1eblLq26Y15+Azk7ey4xh0J/XfJfrCox5LDJiKqI2Q2iwOLOKrlmIaODiSQS8d18jalF6y2K2ePUm0CmShw==", + "license": "MIT", + "optional": true, + "dependencies": { + "minipass": "^3.1.0", + "minipass-sized": "^1.0.3", + "minizlib": "^2.0.0" + }, + "engines": { + "node": ">=8" + }, + "optionalDependencies": { + "encoding": "^0.1.12" + } + }, + "node_modules/minipass-fetch/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/minipass-flush": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/minipass-flush/-/minipass-flush-1.0.5.tgz", + "integrity": "sha512-JmQSYYpPUqX5Jyn1mXaRwOda1uQ8HP5KAT/oDSLCzt1BYRhQU0/hDtsB1ufZfEEzMZ9aAVmsBw8+FWsIXlClWw==", + "license": "ISC", + "optional": true, + "dependencies": { + "minipass": "^3.0.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/minipass-flush/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/minipass-pipeline": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/minipass-pipeline/-/minipass-pipeline-1.2.4.tgz", + "integrity": "sha512-xuIq7cIOt09RPRJ19gdi4b+RiNvDFYe5JH+ggNvBqGqpQXcru3PcRmOZuHBKWK1Txf9+cQ+HMVN4d6z46LZP7A==", + "license": "ISC", + "optional": true, + "dependencies": { + 
"minipass": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/minipass-pipeline/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/minipass-sized": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/minipass-sized/-/minipass-sized-1.0.3.tgz", + "integrity": "sha512-MbkQQ2CTiBMlA2Dm/5cY+9SWFEN8pzzOXi6rlM5Xxq0Yqbda5ZQy9sU75a673FE9ZK0Zsbr6Y5iP6u9nktfg2g==", + "license": "ISC", + "optional": true, + "dependencies": { + "minipass": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/minipass-sized/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/minizlib": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-2.1.2.tgz", + "integrity": "sha512-bAxsR8BVfj60DWXHE3u30oHzfl4G7khkSuPW+qvpd7jFRHm7dLxOjUk1EHACJ/hxLY8phGJ0YhYHZo7jil7Qdg==", + "license": "MIT", + "dependencies": { + "minipass": "^3.0.0", + "yallist": "^4.0.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/minizlib/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + 
"node_modules/mkdirp": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-1.0.4.tgz", + "integrity": "sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==", + "license": "MIT", + "bin": { + "mkdirp": "bin/cmd.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/mkdirp-classic": { + "version": "0.5.3", + "resolved": "https://registry.npmjs.org/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz", + "integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==", + "license": "MIT" + }, + "node_modules/monaco-editor": { + "version": "0.52.2", + "resolved": "https://registry.npmjs.org/monaco-editor/-/monaco-editor-0.52.2.tgz", + "integrity": "sha512-GEQWEZmfkOGLdd3XK8ryrfWz3AIP8YymVXiPHEdewrUq7mh0qrKrfHLNCXcbB6sTnMLnOZ3ztSiKcciFUkIJwQ==", + "peer": true + }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "license": "MIT" + }, + "node_modules/mz": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/mz/-/mz-2.7.0.tgz", + "integrity": "sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==", + "dev": true, + "dependencies": { + "any-promise": "^1.0.0", + "object-assign": "^4.0.1", + "thenify-all": "^1.0.0" + } + }, + "node_modules/nan": { + "version": "2.26.2", + "resolved": "https://registry.npmjs.org/nan/-/nan-2.26.2.tgz", + "integrity": "sha512-0tTvBTYkt3tdGw22nrAy50x7gpbGCCFH3AFcyS5WiUu7Eu4vWlri1woE6qHBSfy11vksDqkiwjOnlR7WV8G1Hw==", + "license": "MIT", + "optional": true + }, + "node_modules/nanoid": { + "version": "3.3.12", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.12.tgz", + "integrity": "sha512-ZB9RH/39qpq5Vu6Y+NmUaFhQR6pp+M2Xt76XBnEwDaGcVAqhlvxrl3B2bKS5D3NH3QR76v3aSrKaF/Kiy7lEtQ==", + "funding": 
[ + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "bin": { + "nanoid": "bin/nanoid.cjs" + }, + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" + } + }, + "node_modules/napi-build-utils": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/napi-build-utils/-/napi-build-utils-2.0.0.tgz", + "integrity": "sha512-GEbrYkbfF7MoNaoh2iGG84Mnf/WZfB0GdGEsM8wz7Expx/LlWf5U8t9nvJKXSp3qr5IsEbK04cBGhol/KwOsWA==", + "license": "MIT" + }, + "node_modules/negotiator": { + "version": "0.6.4", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.4.tgz", + "integrity": "sha512-myRT3DiWPHqho5PrJaIRyaMv2kgYf0mUVgBNOYMuCH5Ki1yEiQaf/ZJuQ62nvpc44wL5WDbTX7yGJi1Neevw8w==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/next": { + "version": "15.5.19", + "resolved": "https://registry.npmjs.org/next/-/next-15.5.19.tgz", + "integrity": "sha512-xNOW6tYshGX1/Oi3F8uuk4gpDeWsSUE/1Z0G5uUMekIxaQ0xc03UXd9II0VQHYMWviMeA0OHpJFAKsHf8bTYVg==", + "license": "MIT", + "dependencies": { + "@next/env": "15.5.19", + "@swc/helpers": "0.5.15", + "caniuse-lite": "^1.0.30001579", + "postcss": "8.4.31", + "styled-jsx": "5.1.6" + }, + "bin": { + "next": "dist/bin/next" + }, + "engines": { + "node": "^18.18.0 || ^19.8.0 || >= 20.0.0" + }, + "optionalDependencies": { + "@next/swc-darwin-arm64": "15.5.19", + "@next/swc-darwin-x64": "15.5.19", + "@next/swc-linux-arm64-gnu": "15.5.19", + "@next/swc-linux-arm64-musl": "15.5.19", + "@next/swc-linux-x64-gnu": "15.5.19", + "@next/swc-linux-x64-musl": "15.5.19", + "@next/swc-win32-arm64-msvc": "15.5.19", + "@next/swc-win32-x64-msvc": "15.5.19", + "sharp": "^0.34.3" + }, + "peerDependencies": { + "@opentelemetry/api": "^1.1.0", + "@playwright/test": "^1.51.1", + "babel-plugin-react-compiler": "*", + "react": "^18.2.0 || 19.0.0-rc-de68d2f4-20241204 || ^19.0.0", + "react-dom": "^18.2.0 || 19.0.0-rc-de68d2f4-20241204 || ^19.0.0", + 
"sass": "^1.3.0" + }, + "peerDependenciesMeta": { + "@opentelemetry/api": { + "optional": true + }, + "@playwright/test": { + "optional": true + }, + "babel-plugin-react-compiler": { + "optional": true + }, + "sass": { + "optional": true + } + } + }, + "node_modules/next/node_modules/postcss": { + "version": "8.4.31", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.31.tgz", + "integrity": "sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "nanoid": "^3.3.6", + "picocolors": "^1.0.0", + "source-map-js": "^1.0.2" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + "node_modules/node-abi": { + "version": "3.74.0", + "resolved": "https://registry.npmjs.org/node-abi/-/node-abi-3.74.0.tgz", + "integrity": "sha512-c5XK0MjkGBrQPGYG24GBADZud0NCbznxNx0ZkS+ebUTrmV1qTDxPxSL8zEAPURXSbLRWVexxmP4986BziahL5w==", + "license": "MIT", + "dependencies": { + "semver": "^7.3.5" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/node-addon-api": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-7.1.1.tgz", + "integrity": "sha512-5m3bsyrjFWE1xf7nz7YXdN4udnVtXK6/Yfgn5qnahL6bCkf2yKt4k3nuTKAtT4r3IG8JNR2ncsIMdZuAzJjHQQ==", + "license": "MIT" + }, + "node_modules/node-cache": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/node-cache/-/node-cache-5.1.2.tgz", + "integrity": "sha512-t1QzWwnk4sjLWaQAS8CHgOJ+RAfmHpxFWmc36IWTiWHQfs0w5JDMBS1b1ZxQteo0vVVuWJvIUKHDkkeK7vIGCg==", + "dependencies": { + "clone": "2.x" + }, + "engines": { + "node": ">= 8.0.0" + } + }, + "node_modules/node-gyp": { + "version": "8.4.1", + "resolved": 
"https://registry.npmjs.org/node-gyp/-/node-gyp-8.4.1.tgz", + "integrity": "sha512-olTJRgUtAb/hOXG0E93wZDs5YiJlgbXxTwQAFHyNlRsXQnYzUaF2aGgujZbw+hR8aF4ZG/rST57bWMWD16jr9w==", + "license": "MIT", + "optional": true, + "dependencies": { + "env-paths": "^2.2.0", + "glob": "^7.1.4", + "graceful-fs": "^4.2.6", + "make-fetch-happen": "^9.1.0", + "nopt": "^5.0.0", + "npmlog": "^6.0.0", + "rimraf": "^3.0.2", + "semver": "^7.3.5", + "tar": "^6.1.2", + "which": "^2.0.2" + }, + "bin": { + "node-gyp": "bin/node-gyp.js" + }, + "engines": { + "node": ">= 10.12.0" + } + }, + "node_modules/node-gyp/node_modules/brace-expansion": { + "version": "1.1.15", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.15.tgz", + "integrity": "sha512-EwOCDEex4quD37XhqM3omwtMoJjr//isUZz1JopUNWms+4Z2ViyM/k1YIRePpoVNnQhENnxtFjLaxNHrT7xIUg==", + "license": "MIT", + "optional": true, + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/node-gyp/node_modules/glob": { + "version": "7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "deprecated": "Glob versions prior to v9 are no longer supported", + "license": "ISC", + "optional": true, + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/node-gyp/node_modules/minimatch": { + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", + "license": "ISC", + "optional": true, + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + 
}, + "node_modules/nopt": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/nopt/-/nopt-5.0.0.tgz", + "integrity": "sha512-Tbj67rffqceeLpcRXrT7vKAN8CwfPeIBgM7E6iBkmKLV7bEMwpGgYLGv0jACUsECaa/vuxP0IjEont6umdMgtQ==", + "license": "ISC", + "optional": true, + "dependencies": { + "abbrev": "1" + }, + "bin": { + "nopt": "bin/nopt.js" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/npmlog": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/npmlog/-/npmlog-6.0.2.tgz", + "integrity": "sha512-/vBvz5Jfr9dT/aFWd0FIRf+T/Q2WBsLENygUaFUqstqsycmZAP/t5BvFJTK0viFmSUxiUKTUplWy5vt+rvKIxg==", + "deprecated": "This package is no longer supported.", + "license": "ISC", + "optional": true, + "dependencies": { + "are-we-there-yet": "^3.0.0", + "console-control-strings": "^1.1.0", + "gauge": "^4.0.3", + "set-blocking": "^2.0.0" + }, + "engines": { + "node": "^12.13.0 || ^14.15.0 || >=16.0.0" + } + }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-hash": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-3.0.0.tgz", + "integrity": "sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw==", + "dev": true, + "engines": { + "node": ">= 6" + } + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": 
"sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "license": "ISC", + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/onetime": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/onetime/-/onetime-5.1.2.tgz", + "integrity": "sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg==", + "license": "MIT", + "optional": true, + "dependencies": { + "mimic-fn": "^2.1.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-map": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/p-map/-/p-map-4.0.0.tgz", + "integrity": "sha512-/bjOqmgETBYB5BoEeGVea8dmvHb2m9GLy1E9W43yeyfP6QQCZGFNa+XRceJEuDB6zqr+gKpIAmlLebMpykw/MQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "aggregate-error": "^3.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/package-json-from-dist": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", + "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==" + }, + "node_modules/parent-module": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "dependencies": { + "callsites": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/parse-json": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/parse-json/-/parse-json-5.2.0.tgz", + "integrity": "sha512-ayCKvm/phCGxOkYRSCM82iDwct8/EonSEgCSxWxD7ve6jHggsFl4fZVQBPRNgQoKiuV/odhFrGzQXZwbifC8Rg==", + "dependencies": { + "@babel/code-frame": "^7.0.0", + "error-ex": "^1.3.1", + 
"json-parse-even-better-errors": "^2.3.0", + "lines-and-columns": "^1.1.6" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/patch-console": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/patch-console/-/patch-console-2.0.0.tgz", + "integrity": "sha512-0YNdUceMdaQwoKce1gatDScmMo5pu/tfABfnzEqeG0gtTmd7mh/WcwgUjtAeOU7N8nFFlbQBnFK2gXW5fGvmMA==", + "license": "MIT", + "optional": true, + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + } + }, + "node_modules/path-is-absolute": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==" + }, + "node_modules/path-scurry": { + "version": "1.11.1", + "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", + "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", + "dependencies": { + "lru-cache": "^10.2.0", + "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" + }, + "engines": { + "node": ">=16 || 14 >=14.18" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/path-type": { + "version": "4.0.0", + "resolved": 
"https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", + "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", + "engines": { + "node": ">=8" + } + }, + "node_modules/picocolors": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==" + }, + "node_modules/picomatch": { + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz", + "integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/pify": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", + "integrity": "sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/pirates": { + "version": "4.0.6", + "resolved": "https://registry.npmjs.org/pirates/-/pirates-4.0.6.tgz", + "integrity": "sha512-saLsH7WeYYPiD25LDuLRRY/i+6HaPYr6G1OUlN39otzkSTxKnubR9RTxS3/Kk50s1g2JTgFwWQDQyplC5/SHZg==", + "dev": true, + "engines": { + "node": ">= 6" + } + }, + "node_modules/postcss": { + "version": "8.5.15", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.15.tgz", + "integrity": "sha512-FfR8sjd4em2T6fb3I2MwAJU7HWVMr9zba+enmQeeWFfCbm+UOC/0X4DS8XtpUTMwWMGbjKYP7xjfNekzyGmB3A==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + 
"nanoid": "^3.3.12", + "picocolors": "^1.1.1", + "source-map-js": "^1.2.1" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + "node_modules/postcss-import": { + "version": "15.1.0", + "resolved": "https://registry.npmjs.org/postcss-import/-/postcss-import-15.1.0.tgz", + "integrity": "sha512-hpr+J05B2FVYUAXHeK1YyI267J/dDDhMU6B6civm8hSY1jYJnBXxzKDKDswzJmtLHryrjhnDjqqp/49t8FALew==", + "dev": true, + "dependencies": { + "postcss-value-parser": "^4.0.0", + "read-cache": "^1.0.0", + "resolve": "^1.1.7" + }, + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "postcss": "^8.0.0" + } + }, + "node_modules/postcss-js": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/postcss-js/-/postcss-js-4.0.1.tgz", + "integrity": "sha512-dDLF8pEO191hJMtlHFPRa8xsizHaM82MLfNkUHdUtVEV3tgTp5oj+8qbEqYM57SLfc74KSbw//4SeJma2LRVIw==", + "dev": true, + "dependencies": { + "camelcase-css": "^2.0.1" + }, + "engines": { + "node": "^12 || ^14 || >= 16" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + "peerDependencies": { + "postcss": "^8.4.21" + } + }, + "node_modules/postcss-load-config": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/postcss-load-config/-/postcss-load-config-4.0.2.tgz", + "integrity": "sha512-bSVhyJGL00wMVoPUzAVAnbEoWyqRxkjv64tUl427SKnPrENtq6hJwUojroMz2VB+Q1edmi4IfrAPpami5VVgMQ==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "lilconfig": "^3.0.0", + "yaml": "^2.3.4" + }, + "engines": { + "node": ">= 14" + }, + "peerDependencies": { + "postcss": ">=8.0.9", + "ts-node": ">=9.0.0" + }, + "peerDependenciesMeta": { + "postcss": { + "optional": true + }, + "ts-node": { + "optional": true + } + } + }, + "node_modules/postcss-nested": { + "version": "6.2.0", + "resolved": 
"https://registry.npmjs.org/postcss-nested/-/postcss-nested-6.2.0.tgz", + "integrity": "sha512-HQbt28KulC5AJzG+cZtj9kvKB93CFCdLvog1WFLf1D+xmMvPGlBstkpTEZfK5+AN9hfJocyBFCNiqyS48bpgzQ==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "postcss-selector-parser": "^6.1.1" + }, + "engines": { + "node": ">=12.0" + }, + "peerDependencies": { + "postcss": "^8.2.14" + } + }, + "node_modules/postcss-selector-parser": { + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.1.2.tgz", + "integrity": "sha512-Q8qQfPiZ+THO/3ZrOrO0cJJKfpYCagtMUkXbnEfmgUjwXg6z/WBeOyS9APBBPCTSiDV+s4SwQGu8yFsiMRIudg==", + "dev": true, + "dependencies": { + "cssesc": "^3.0.0", + "util-deprecate": "^1.0.2" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/postcss-value-parser": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz", + "integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==", + "dev": true + }, + "node_modules/prebuild-install": { + "version": "7.1.3", + "resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.3.tgz", + "integrity": "sha512-8Mf2cbV7x1cXPUILADGI3wuhfqWvtiLA1iclTDbFRZkgRQS0NqsPZphna9V+HyTEadheuPmjaJMsbzKQFOzLug==", + "license": "MIT", + "dependencies": { + "detect-libc": "^2.0.0", + "expand-template": "^2.0.3", + "github-from-package": "0.0.0", + "minimist": "^1.2.3", + "mkdirp-classic": "^0.5.3", + "napi-build-utils": "^2.0.0", + "node-abi": "^3.3.0", + "pump": "^3.0.0", + "rc": "^1.2.7", + "simple-get": "^4.0.0", + "tar-fs": "^2.0.0", + "tunnel-agent": "^0.6.0" + }, + "bin": { + "prebuild-install": "bin.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/prettier": { + "version": 
"3.5.1", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.5.1.tgz", + "integrity": "sha512-hPpFQvHwL3Qv5AdRvBFMhnKo4tYxp0ReXiPn2bxkiohEX6mBeBwEpBSQTkD458RaaDKQMYSp4hX4UtfUTA5wDw==", + "dev": true, + "license": "MIT", + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, + "node_modules/prettier-basic": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/prettier-basic/-/prettier-basic-1.0.0.tgz", + "integrity": "sha512-cBAeJbegnXLEOUX9q+xU5l8zOehkZkR9dG4VSrN95hwRqBrdGCPzYmxG9ojdgxGuX7Y2hkqKZq9tlIeAvCvOAA==", + "dev": true, + "license": "MIT" + }, + "node_modules/prisma": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/prisma/-/prisma-6.3.1.tgz", + "integrity": "sha512-JKCZWvBC3enxk51tY4TWzS4b5iRt4sSU1uHn2I183giZTvonXaQonzVtjLzpOHE7qu9MxY510kAtFGJwryKe3Q==", + "hasInstallScript": true, + "license": "Apache-2.0", + "dependencies": { + "@prisma/engines": "6.3.1" + }, + "bin": { + "prisma": "build/index.js" + }, + "engines": { + "node": ">=18.18" + }, + "optionalDependencies": { + "fsevents": "2.3.3" + }, + "peerDependencies": { + "typescript": ">=5.1.0" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/process": { + "version": "0.11.10", + "resolved": "https://registry.npmjs.org/process/-/process-0.11.10.tgz", + "integrity": "sha512-cdGef/drWFoydD1JsMzuFf8100nZl+GT+yacc2bEced5f9Rjk4z+WtFUTBu9PhOi9j/jfmBPu0mMEY4wIdAF8A==", + "engines": { + "node": ">= 0.6.0" + } + }, + "node_modules/process-nextick-args": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", + "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==" + }, + "node_modules/promise-inflight": { + "version": "1.0.1", + "resolved": 
"https://registry.npmjs.org/promise-inflight/-/promise-inflight-1.0.1.tgz", + "integrity": "sha512-6zWPyEOFaQBJYcGMHBKTKJ3u6TBsnMFOIZSa6ce1e/ZrrsOlnHRHbabMjLiBYKp+n44X9eUI6VUPaukCXHuG4g==", + "license": "ISC", + "optional": true + }, + "node_modules/promise-retry": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/promise-retry/-/promise-retry-2.0.1.tgz", + "integrity": "sha512-y+WKFlBR8BGXnsNlIHFGPZmyDf3DFMoLhaflAnyZgV6rG6xu+JwesTo2Q9R6XwYmtmwAFCkAk3e35jEdoeh/3g==", + "license": "MIT", + "optional": true, + "dependencies": { + "err-code": "^2.0.2", + "retry": "^0.12.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/prop-types": { + "version": "15.8.1", + "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", + "integrity": "sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==", + "dependencies": { + "loose-envify": "^1.4.0", + "object-assign": "^4.1.1", + "react-is": "^16.13.1" + } + }, + "node_modules/proxy-from-env": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-2.1.0.tgz", + "integrity": "sha512-cJ+oHTW1VAEa8cJslgmUZrc+sjRKgAKl3Zyse6+PV38hZe/V6Z14TbCuXcan9F9ghlz4QrFr2c92TNF82UkYHA==", + "license": "MIT", + "engines": { + "node": ">=10" + } + }, + "node_modules/pump": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.2.tgz", + "integrity": "sha512-tUPXtzlGM8FE3P0ZL6DVs/3P58k9nk8/jZeQCurTJylQA8qFYzHFfhBJkuqyE0FifOsQ0uKWekiZ5g8wtr28cw==", + "license": "MIT", + "dependencies": { + "end-of-stream": "^1.1.0", + "once": "^1.3.1" + } + }, + "node_modules/queue-microtask": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" 
+ }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ] + }, + "node_modules/rc": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz", + "integrity": "sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==", + "license": "(BSD-2-Clause OR MIT OR Apache-2.0)", + "dependencies": { + "deep-extend": "^0.6.0", + "ini": "~1.3.0", + "minimist": "^1.2.0", + "strip-json-comments": "~2.0.1" + }, + "bin": { + "rc": "cli.js" + } + }, + "node_modules/react": { + "version": "19.2.4", + "resolved": "https://registry.npmjs.org/react/-/react-19.2.4.tgz", + "integrity": "sha512-9nfp2hYpCwOjAN+8TZFGhtWEwgvWHXqESH8qT89AT/lWklpLON22Lc8pEtnpsZz7VmawabSU0gCjnj8aC0euHQ==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-dom": { + "version": "19.2.4", + "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.4.tgz", + "integrity": "sha512-AXJdLo8kgMbimY95O2aKQqsz2iWi9jMgKJhRBAxECE4IFxfcazB2LmzloIoibJI3C12IlY20+KFaLv+71bUJeQ==", + "license": "MIT", + "dependencies": { + "scheduler": "^0.27.0" + }, + "peerDependencies": { + "react": "^19.2.4" + } + }, + "node_modules/react-dropzone": { + "version": "14.3.5", + "resolved": "https://registry.npmjs.org/react-dropzone/-/react-dropzone-14.3.5.tgz", + "integrity": "sha512-9nDUaEEpqZLOz5v5SUcFA0CjM4vq8YbqO0WRls+EYT7+DvxUdzDPKNCPLqGfj3YL9MsniCLCD4RFA6M95V6KMQ==", + "dependencies": { + "attr-accept": "^2.2.4", + "file-selector": "^2.1.0", + "prop-types": "^15.8.1" + }, + "engines": { + "node": ">= 10.13" + }, + "peerDependencies": { + "react": ">= 16.8 || 18.0.0" + } + }, + "node_modules/react-global-hooks": { + "version": "1.3.5", + "resolved": "https://registry.npmjs.org/react-global-hooks/-/react-global-hooks-1.3.5.tgz", + "integrity": "sha512-xEvDSV6fkZ1ZAZ2qgrldw6d51awCtru6SzSVuWbrOi+tVIrGwroQLC2tdpFBYmszUCGOKi7UTuqOCYDyeJqvug==", 
+ "peerDependencies": { + "react": "^16 || 17 || 18 || 19" + } + }, + "node_modules/react-icons": { + "version": "5.5.0", + "resolved": "https://registry.npmjs.org/react-icons/-/react-icons-5.5.0.tgz", + "integrity": "sha512-MEFcXdkP3dLo8uumGI5xN3lDFNsRtrjbOEKDLD7yv76v4wpnEq2Lt2qeHaQOr34I/wPN3s3+N08WkQ+CW37Xiw==", + "peerDependencies": { + "react": "*" + } + }, + "node_modules/react-is": { + "version": "16.13.1", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==" + }, + "node_modules/react-select": { + "version": "5.10.1", + "resolved": "https://registry.npmjs.org/react-select/-/react-select-5.10.1.tgz", + "integrity": "sha512-roPEZUL4aRZDx6DcsD+ZNreVl+fM8VsKn0Wtex1v4IazH60ILp5xhdlp464IsEAlJdXeD+BhDAFsBVMfvLQueA==", + "dependencies": { + "@babel/runtime": "^7.12.0", + "@emotion/cache": "^11.4.0", + "@emotion/react": "^11.8.1", + "@floating-ui/dom": "^1.0.1", + "@types/react-transition-group": "^4.4.0", + "memoize-one": "^6.0.0", + "prop-types": "^15.6.0", + "react-transition-group": "^4.3.0", + "use-isomorphic-layout-effect": "^1.2.0" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, + "node_modules/react-transition-group": { + "version": "4.4.5", + "resolved": "https://registry.npmjs.org/react-transition-group/-/react-transition-group-4.4.5.tgz", + "integrity": "sha512-pZcd1MCJoiKiBR2NRxeCRg13uCXbydPnmB4EOeRrY7480qNWO8IIgQG6zlDkm6uRMsURXPuKq0GWtiM59a5Q6g==", + "dependencies": { + "@babel/runtime": "^7.5.5", + "dom-helpers": "^5.0.1", + "loose-envify": "^1.4.0", + "prop-types": "^15.6.2" + }, + "peerDependencies": { + "react": ">=16.6.0", + "react-dom": ">=16.6.0" + } + }, + "node_modules/react-virtuoso": { + "version": "4.18.7", + "resolved": "https://registry.npmjs.org/react-virtuoso/-/react-virtuoso-4.18.7.tgz", + "integrity": 
"sha512-xNF5zDGEEIMB7cKwcen/pLig0YDf6OnfFrVgKFa7sHPf9fRem0CaLshyObbBcP88jzn0enavL39EgplgdyT21g==", + "license": "MIT", + "peerDependencies": { + "react": ">=16 || >=17 || >= 18 || >= 19", + "react-dom": ">=16 || >=17 || >= 18 || >=19" + } + }, + "node_modules/react-zoom-pan-pinch": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/react-zoom-pan-pinch/-/react-zoom-pan-pinch-4.0.3.tgz", + "integrity": "sha512-N2Hi6L78fFmhRra+ORpFSW7WST5x6kxpOPplIvtB0b7b+U2anpo1z1wLgaWRPS2kUSqcraRG+JgBCIlDJnqqAg==", + "license": "MIT", + "engines": { + "node": ">=8", + "npm": ">=5" + }, + "peerDependencies": { + "react": "*", + "react-dom": "*" + } + }, + "node_modules/read-cache": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", + "integrity": "sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA==", + "dev": true, + "dependencies": { + "pify": "^2.3.0" + } + }, + "node_modules/readable-stream": { + "version": "3.6.2", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz", + "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==", + "license": "MIT", + "dependencies": { + "inherits": "^2.0.3", + "string_decoder": "^1.1.1", + "util-deprecate": "^1.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/readdir-glob": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/readdir-glob/-/readdir-glob-1.1.3.tgz", + "integrity": "sha512-v05I2k7xN8zXvPD9N+z/uhXPaj0sUFCe2rcWZIpBsqxfP7xXFQ0tipAd/wjj1YxWyWtUS5IDJpOG82JKt2EAVA==", + "dependencies": { + "minimatch": "^5.1.0" + } + }, + "node_modules/readdir-glob/node_modules/minimatch": { + "version": "5.1.9", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.9.tgz", + "integrity": "sha512-7o1wEA2RyMP7Iu7GNba9vc0RWWGACJOCZBJX2GJWip0ikV+wcOsgVuY9uE8CPiyQhkGFSlhuSkZPavN7u1c2Fw==", + "license": "ISC", + 
"dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/readdirp": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", + "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", + "dev": true, + "dependencies": { + "picomatch": "^2.2.1" + }, + "engines": { + "node": ">=8.10.0" + } + }, + "node_modules/require-directory": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz", + "integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/resolve": { + "version": "1.22.10", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.10.tgz", + "integrity": "sha512-NPRy+/ncIMeDlTAsuqwKIiferiawhefFJtkNSW0qZJEqMEb+qBt/77B/jGeeek+F0uOeN05CDa6HXbbIgtVX4w==", + "dependencies": { + "is-core-module": "^2.16.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve-from": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "engines": { + "node": ">=4" + } + }, + "node_modules/restore-cursor": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/restore-cursor/-/restore-cursor-4.0.0.tgz", + "integrity": "sha512-I9fPXU9geO9bHOt9pHHOhOkYerIMsmVaWB0rA2AI9ERh/+x/i7MV5HKBNrg+ljO5eoPVgCcnFuRjJ9uH6I/3eg==", + "license": "MIT", + "optional": true, + "dependencies": { + "onetime": "^5.1.0", + "signal-exit": "^3.0.2" + }, + "engines": { + "node": "^12.20.0 || ^14.13.1 || 
>=16.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/restore-cursor/node_modules/signal-exit": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", + "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", + "license": "ISC", + "optional": true + }, + "node_modules/retry": { + "version": "0.12.0", + "resolved": "https://registry.npmjs.org/retry/-/retry-0.12.0.tgz", + "integrity": "sha512-9LkiTwjUh6rT555DtE9rTX+BKByPfrMzEAtnlEtdEwr3Nkffwiihqe2bWADg+OQRjt9gl6ICdmB/ZFDCGAtSow==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 4" + } + }, + "node_modules/reusify": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", + "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "dev": true, + "engines": { + "iojs": ">=1.0.0", + "node": ">=0.10.0" + } + }, + "node_modules/rimraf": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", + "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "deprecated": "Rimraf versions prior to v4 are no longer supported", + "license": "ISC", + "optional": true, + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/rimraf/node_modules/brace-expansion": { + "version": "1.1.15", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.15.tgz", + "integrity": "sha512-EwOCDEex4quD37XhqM3omwtMoJjr//isUZz1JopUNWms+4Z2ViyM/k1YIRePpoVNnQhENnxtFjLaxNHrT7xIUg==", + "license": "MIT", + "optional": true, + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/rimraf/node_modules/glob": { + "version": 
"7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "deprecated": "Glob versions prior to v9 are no longer supported", + "license": "ISC", + "optional": true, + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/rimraf/node_modules/minimatch": { + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", + "license": "ISC", + "optional": true, + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/run-parallel": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "queue-microtask": "^1.2.2" + } + }, + "node_modules/rxjs": { + "version": "7.8.2", + "resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.8.2.tgz", + "integrity": "sha512-dhKf903U/PQZY6boNNtAGdWbG85WAbjT/1xYoZIC7FAY0yWapOBQVsVrDl58W86//e1VpMNBtRV4MaXfdMySFA==", + "dev": true, + "dependencies": { + "tslib": "^2.1.0" + } + }, + "node_modules/safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": 
"sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", + "license": "MIT", + "optional": true + }, + "node_modules/scheduler": { + "version": "0.27.0", + "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.27.0.tgz", + "integrity": "sha512-eNv+WrVbKu1f3vbYJT/xtiF5syA5HPIMtf9IgY/nKg0sWqzAUEvqY/xm7OcZc/qafLx/iO9FgOmeSAp4v5ti/Q==", + "license": "MIT" + }, + "node_modules/semver": { + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/set-blocking": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/set-blocking/-/set-blocking-2.0.0.tgz", + "integrity": "sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==", + "license": "ISC", + "optional": true + }, + "node_modules/sharp": { + "version": "0.34.5", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.34.5.tgz", + "integrity": "sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==", + "hasInstallScript": true, + "license": "Apache-2.0", + "optional": true, + "dependencies": { + "@img/colour": "^1.0.0", + "detect-libc": "^2.1.2", + "semver": "^7.7.3" + }, + "engines": { + "node": "^18.17.0 
|| ^20.3.0 || >=21.0.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-darwin-arm64": "0.34.5", + "@img/sharp-darwin-x64": "0.34.5", + "@img/sharp-libvips-darwin-arm64": "1.2.4", + "@img/sharp-libvips-darwin-x64": "1.2.4", + "@img/sharp-libvips-linux-arm": "1.2.4", + "@img/sharp-libvips-linux-arm64": "1.2.4", + "@img/sharp-libvips-linux-ppc64": "1.2.4", + "@img/sharp-libvips-linux-riscv64": "1.2.4", + "@img/sharp-libvips-linux-s390x": "1.2.4", + "@img/sharp-libvips-linux-x64": "1.2.4", + "@img/sharp-libvips-linuxmusl-arm64": "1.2.4", + "@img/sharp-libvips-linuxmusl-x64": "1.2.4", + "@img/sharp-linux-arm": "0.34.5", + "@img/sharp-linux-arm64": "0.34.5", + "@img/sharp-linux-ppc64": "0.34.5", + "@img/sharp-linux-riscv64": "0.34.5", + "@img/sharp-linux-s390x": "0.34.5", + "@img/sharp-linux-x64": "0.34.5", + "@img/sharp-linuxmusl-arm64": "0.34.5", + "@img/sharp-linuxmusl-x64": "0.34.5", + "@img/sharp-wasm32": "0.34.5", + "@img/sharp-win32-arm64": "0.34.5", + "@img/sharp-win32-ia32": "0.34.5", + "@img/sharp-win32-x64": "0.34.5" + } + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "engines": { + "node": ">=8" + } + }, + "node_modules/shell-quote": { + "version": "1.8.4", + "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.4.tgz", + "integrity": 
"sha512-VsC6n6vz1ihYYyZZwX7YZSF5l5x36ca17OC+a69h94YqB7X6XLwf+5MOgynYir2SLFUbl8gIYvBo8K8RoNQ6bQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/signal-exit": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", + "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/simple-concat": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/simple-concat/-/simple-concat-1.0.1.tgz", + "integrity": "sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/simple-get": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/simple-get/-/simple-get-4.0.1.tgz", + "integrity": "sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT", + "dependencies": { + "decompress-response": "^6.0.0", + "once": "^1.3.1", + "simple-concat": "^1.0.0" + } + }, + "node_modules/slice-ansi": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-8.0.0.tgz", + "integrity": "sha512-stxByr12oeeOyY2BlviTNQlYV5xOj47GirPr4yA1hE9JCtxfQN0+tVbkxwCtYDQWhEKWFHsEK48ORg5jrouCAg==", + "license": "MIT", + 
"optional": true, + "dependencies": { + "ansi-styles": "^6.2.3", + "is-fullwidth-code-point": "^5.1.0" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/chalk/slice-ansi?sponsor=1" + } + }, + "node_modules/slice-ansi/node_modules/is-fullwidth-code-point": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-5.1.0.tgz", + "integrity": "sha512-5XHYaSyiqADb4RnZ1Bdad6cPp8Toise4TzEjcOYDHZkTCbKgiUl7WTUCpNWHuxmDt91wnsZBc9xinNzopv3JMQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "get-east-asian-width": "^1.3.1" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/smart-buffer": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/smart-buffer/-/smart-buffer-4.2.0.tgz", + "integrity": "sha512-94hK0Hh8rPqQl2xXc3HsaBoOXKV20MToPkcXvwbISWLEs+64sBq5kFgn2kJDHb1Pry9yrP0dxrCI9RRci7RXKg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 6.0.0", + "npm": ">= 3.0.0" + } + }, + "node_modules/socks": { + "version": "2.8.9", + "resolved": "https://registry.npmjs.org/socks/-/socks-2.8.9.tgz", + "integrity": "sha512-LJhUYUvItdQ0LkJTmPeaEObWXAqFyfmP85x0tch/ez9cahmhlBBLbIqDFnvBnUJGagb0JbIQrkBs1wJ+yRYpEw==", + "license": "MIT", + "optional": true, + "dependencies": { + "ip-address": "^10.1.1", + "smart-buffer": "^4.2.0" + }, + "engines": { + "node": ">= 10.0.0", + "npm": ">= 3.0.0" + } + }, + "node_modules/socks-proxy-agent": { + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/socks-proxy-agent/-/socks-proxy-agent-6.2.1.tgz", + "integrity": "sha512-a6KW9G+6B3nWZ1yB8G7pJwL3ggLy1uTzKAgCb7ttblwqdz9fMGJUuTy3uFzEP48FAs9FLILlmzDlE2JJhVQaXQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "agent-base": "^6.0.2", + "debug": "^4.3.3", + "socks": "^2.6.2" + }, + "engines": { + "node": ">= 10" + } + }, + "node_modules/source-map": { + 
"version": "0.5.7", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.5.7.tgz", + "integrity": "sha512-LbrmJOMUSdEVxIKvdcJzQC+nQhe8FUZQTXQy6+I75skNgn3OoQ0DZA8YnFa7gp8tqtL3KPf1kmo0R5DoApeSGQ==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-js": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-support": { + "version": "0.5.21", + "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz", + "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", + "dev": true, + "dependencies": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + } + }, + "node_modules/source-map-support/node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/sqlite3": { + "version": "5.1.7", + "resolved": "https://registry.npmjs.org/sqlite3/-/sqlite3-5.1.7.tgz", + "integrity": "sha512-GGIyOiFaG+TUra3JIfkI/zGP8yZYLPQ0pl1bH+ODjiX57sPhrLU5sQJn1y9bDKZUFYkX1crlrPfSYt0BKKdkog==", + "hasInstallScript": true, + "license": "BSD-3-Clause", + "dependencies": { + "bindings": "^1.5.0", + "node-addon-api": "^7.0.0", + "prebuild-install": "^7.1.1", + "tar": "^6.1.11" + }, + "optionalDependencies": { + "node-gyp": "8.x" + }, + "peerDependencies": { + "node-gyp": "8.x" + }, + "peerDependenciesMeta": { + "node-gyp": { + "optional": true + } + } + }, + "node_modules/ssri": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/ssri/-/ssri-8.0.1.tgz", + "integrity": 
"sha512-97qShzy1AiyxvPNIkLWoGua7xoQzzPjQ0HAH4B0rWKo7SZ6USuPcrUiAFrws0UH8RrbWmgq3LMTObhPIHbbBeQ==", + "license": "ISC", + "optional": true, + "dependencies": { + "minipass": "^3.1.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/ssri/node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "license": "ISC", + "optional": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/stack-utils": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/stack-utils/-/stack-utils-2.0.6.tgz", + "integrity": "sha512-XlkWvfIm6RmsWtNJx+uqtKLS8eqFbxUg0ZzLXqY0caEy9l7hruX8IpiDnjsLavoBgqCCR71TqWO8MaXYheJ3RQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "escape-string-regexp": "^2.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/stack-utils/node_modules/escape-string-regexp": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz", + "integrity": "sha512-UpzcLCXolUWcNu5HtVMHYdXJjArjsF9C0aNnquZYY4uW/Vu0miy5YoWvbV345HauVvcAUnpRuhMMcqTcGOY2+w==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/state-local": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/state-local/-/state-local-1.0.7.tgz", + "integrity": "sha512-HTEHMNieakEnoe33shBYcZ7NX83ACUjCu8c40iOGEZsngj9zRnkqS9j1pqQPXwobB0ZcVTk27REb7COQ0UR59w==" + }, + "node_modules/streamx": { + "version": "2.22.1", + "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.22.1.tgz", + "integrity": "sha512-znKXEBxfatz2GBNK02kRnCXjV+AA4kjZIUxeWSr3UGirZMJfTE9uiwKHobnbgxWyL/JWro8tTq+vOqAK1/qbSA==", + "dependencies": { + "fast-fifo": "^1.3.2", + "text-decoder": "^1.1.0" + }, + "optionalDependencies": { + "bare-events": 
"^2.2.0" + } + }, + "node_modules/string_decoder": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", + "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", + "license": "MIT", + "dependencies": { + "safe-buffer": "~5.2.0" + } + }, + "node_modules/string-width": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz", + "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==", + "dependencies": { + "eastasianwidth": "^0.2.0", + "emoji-regex": "^9.2.2", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/string-width-cjs": { + "name": "string-width", + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/string-width-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "engines": { + "node": ">=8" + } + }, + "node_modules/string-width-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==" + }, + "node_modules/string-width-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": 
"https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.2.0.tgz", + "integrity": "sha512-yDPMNjp4WyfYBkHnjIRLfca1i6KMyGCtsVgoKe/z1+6vukgaENdgGBZt+ZmKPc4gavvEZ5OgHfHdrazhgNyG7w==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^6.2.2" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/strip-ansi?sponsor=1" + } + }, + "node_modules/strip-ansi-cjs": { + "name": "strip-ansi", + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-bom": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-3.0.0.tgz", + "integrity": "sha512-vavAMRXOgBVNF6nyEEmL3DBK19iRpDcoIwW+swQ+CbGiu7lju6t+JklA1MHweoWtadgt4ISVUsXLyDq34ddcwA==", + "dev": true, + "engines": { + "node": ">=4" + } + }, + "node_modules/strip-json-comments": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz", + "integrity": "sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==", + "license": "MIT", + "engines": { + "node": 
">=0.10.0" + } + }, + "node_modules/styled-jsx": { + "version": "5.1.6", + "resolved": "https://registry.npmjs.org/styled-jsx/-/styled-jsx-5.1.6.tgz", + "integrity": "sha512-qSVyDTeMotdvQYoHWLNGwRFJHC+i+ZvdBRYosOFgC+Wg1vx4frN2/RG/NA7SYqqvKNLf39P2LSRA2pu6n0XYZA==", + "dependencies": { + "client-only": "0.0.1" + }, + "engines": { + "node": ">= 12.0.0" + }, + "peerDependencies": { + "react": ">= 16.8.0 || 17.x.x || ^18.0.0-0 || ^19.0.0-0" + }, + "peerDependenciesMeta": { + "@babel/core": { + "optional": true + }, + "babel-plugin-macros": { + "optional": true + } + } + }, + "node_modules/stylis": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.2.0.tgz", + "integrity": "sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==" + }, + "node_modules/sucrase": { + "version": "3.35.0", + "resolved": "https://registry.npmjs.org/sucrase/-/sucrase-3.35.0.tgz", + "integrity": "sha512-8EbVDiu9iN/nESwxeSxDKe0dunta1GOlHufmSSXxMD2z2/tMZpDMpvXQGsc+ajGo8y2uYUmixaSRUc/QPoQ0GA==", + "dev": true, + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.2", + "commander": "^4.0.0", + "glob": "^10.3.10", + "lines-and-columns": "^1.1.6", + "mz": "^2.7.0", + "pirates": "^4.0.1", + "ts-interface-checker": "^0.1.9" + }, + "bin": { + "sucrase": "bin/sucrase", + "sucrase-node": "bin/sucrase-node" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + } + }, + "node_modules/supports-color": { + "version": "8.1.1", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", + "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", + "dev": true, + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/supports-color?sponsor=1" + } + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": 
"https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/systeminformation": { + "version": "5.31.7", + "resolved": "https://registry.npmjs.org/systeminformation/-/systeminformation-5.31.7.tgz", + "integrity": "sha512-/8NC53e5nP9nmhn42/ncdOkyJnOoue/Vy+tJOyUGd1Yv66G069wK4rrziwhrqDETgk78CudTQupw5z19S5uoZw==", + "license": "MIT", + "os": [ + "darwin", + "linux", + "win32", + "freebsd", + "openbsd", + "netbsd", + "sunos", + "android" + ], + "bin": { + "systeminformation": "lib/cli.js" + }, + "engines": { + "node": ">=8.0.0" + }, + "funding": { + "type": "Buy me a coffee", + "url": "https://www.buymeacoffee.com/systeminfo" + } + }, + "node_modules/tabbable": { + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/tabbable/-/tabbable-6.2.0.tgz", + "integrity": "sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew==" + }, + "node_modules/tagged-tag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/tagged-tag/-/tagged-tag-1.0.0.tgz", + "integrity": "sha512-yEFYrVhod+hdNyx7g5Bnkkb0G6si8HJurOoOEgC8B/O0uXLHlaey/65KRv6cuWBNhBgHKAROVpc7QyYqE5gFng==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/tailwindcss": { + "version": "3.4.17", + "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.17.tgz", + "integrity": "sha512-w33E2aCvSDP0tW9RZuNXadXlkHXqFzSkQew/aIa2i/Sj8fThxwovwlXHSPXTbAHwEIhBFXAedUhP2tueAKP8Og==", + "dev": true, + "dependencies": { + "@alloc/quick-lru": "^5.2.0", + "arg": "^5.0.2", + "chokidar": "^3.6.0", + "didyoumean": "^1.2.2", + "dlv": "^1.1.3", + "fast-glob": "^3.3.2", + 
"glob-parent": "^6.0.2", + "is-glob": "^4.0.3", + "jiti": "^1.21.6", + "lilconfig": "^3.1.3", + "micromatch": "^4.0.8", + "normalize-path": "^3.0.0", + "object-hash": "^3.0.0", + "picocolors": "^1.1.1", + "postcss": "^8.4.47", + "postcss-import": "^15.1.0", + "postcss-js": "^4.0.1", + "postcss-load-config": "^4.0.2", + "postcss-nested": "^6.2.0", + "postcss-selector-parser": "^6.1.2", + "resolve": "^1.22.8", + "sucrase": "^3.35.0" + }, + "bin": { + "tailwind": "lib/cli.js", + "tailwindcss": "lib/cli.js" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/tar": { + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/tar/-/tar-6.2.1.tgz", + "integrity": "sha512-DZ4yORTwrbTj/7MZYq2w+/ZFdI6OZ/f9SFHR+71gIVUZhOQPHzVCLpvRnPgyaMpfWxxk/4ONva3GQSyNIKRv6A==", + "license": "ISC", + "dependencies": { + "chownr": "^2.0.0", + "fs-minipass": "^2.0.0", + "minipass": "^5.0.0", + "minizlib": "^2.1.1", + "mkdirp": "^1.0.3", + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/tar-fs": { + "version": "2.1.4", + "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.4.tgz", + "integrity": "sha512-mDAjwmZdh7LTT6pNleZ05Yt65HC3E+NiQzl672vQG38jIrehtJk/J3mNwIg+vShQPcLF/LV7CMnDW6vjj6sfYQ==", + "dependencies": { + "chownr": "^1.1.1", + "mkdirp-classic": "^0.5.2", + "pump": "^3.0.0", + "tar-stream": "^2.1.4" + } + }, + "node_modules/tar-fs/node_modules/chownr": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz", + "integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==", + "license": "ISC" + }, + "node_modules/tar-stream": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-2.2.0.tgz", + "integrity": "sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==", + "license": "MIT", + "dependencies": { + "bl": "^4.0.3", + "end-of-stream": "^1.4.1", + "fs-constants": 
"^1.0.0", + "inherits": "^2.0.3", + "readable-stream": "^3.1.1" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/tar/node_modules/minipass": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-5.0.0.tgz", + "integrity": "sha512-3FnjYuehv9k6ovOEbyOswadCDPX1piCfhV8ncmYtHOjuPwylVWsghTLo7rabjC3Rx5xD4HDx8Wm1xnMF7S5qFQ==", + "license": "ISC", + "engines": { + "node": ">=8" + } + }, + "node_modules/terminal-size": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/terminal-size/-/terminal-size-4.0.1.tgz", + "integrity": "sha512-avMLDQpUI9I5XFrklECw1ZEUPJhqzcwSWsyyI8blhRLT+8N1jLJWLWWYQpB2q2xthq8xDvjZPISVh53T/+CLYQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/text-decoder": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/text-decoder/-/text-decoder-1.2.3.tgz", + "integrity": "sha512-3/o9z3X0X0fTupwsYvR03pJ/DjWuqqrfwBgTQzdWDiQSm9KitAyz/9WqsT2JQW7KV2m+bC2ol/zqpW37NHxLaA==", + "dependencies": { + "b4a": "^1.6.4" + } + }, + "node_modules/thenify": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/thenify/-/thenify-3.3.1.tgz", + "integrity": "sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==", + "dev": true, + "dependencies": { + "any-promise": "^1.0.0" + } + }, + "node_modules/thenify-all": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/thenify-all/-/thenify-all-1.6.0.tgz", + "integrity": "sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==", + "dev": true, + "dependencies": { + "thenify": ">= 3.1.0 < 4" + }, + "engines": { + "node": ">=0.8" + } + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": 
"sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/tree-kill": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/tree-kill/-/tree-kill-1.2.2.tgz", + "integrity": "sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==", + "dev": true, + "bin": { + "tree-kill": "cli.js" + } + }, + "node_modules/ts-interface-checker": { + "version": "0.1.13", + "resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz", + "integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==", + "dev": true + }, + "node_modules/ts-node-dev": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ts-node-dev/-/ts-node-dev-2.0.0.tgz", + "integrity": "sha512-ywMrhCfH6M75yftYvrvNarLEY+SUXtUvU8/0Z6llrHQVBx12GiFk5sStF8UdfE/yfzk9IAq7O5EEbTQsxlBI8w==", + "dev": true, + "dependencies": { + "chokidar": "^3.5.1", + "dynamic-dedupe": "^0.3.0", + "minimist": "^1.2.6", + "mkdirp": "^1.0.4", + "resolve": "^1.0.0", + "rimraf": "^2.6.1", + "source-map-support": "^0.5.12", + "tree-kill": "^1.2.2", + "ts-node": "^10.4.0", + "tsconfig": "^7.0.0" + }, + "bin": { + "ts-node-dev": "lib/bin.js", + "tsnd": "lib/bin.js" + }, + "engines": { + "node": ">=0.8.0" + }, + "peerDependencies": { + "node-notifier": "*", + "typescript": "*" + }, + "peerDependenciesMeta": { + "node-notifier": { + "optional": true + } + } + }, + "node_modules/ts-node-dev/node_modules/arg": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz", + "integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==", + "dev": true + }, + "node_modules/ts-node-dev/node_modules/brace-expansion": { + "version": "1.1.15", + "resolved": 
"https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.15.tgz", + "integrity": "sha512-EwOCDEex4quD37XhqM3omwtMoJjr//isUZz1JopUNWms+4Z2ViyM/k1YIRePpoVNnQhENnxtFjLaxNHrT7xIUg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/ts-node-dev/node_modules/glob": { + "version": "7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "deprecated": "Glob versions prior to v9 are no longer supported", + "dev": true, + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/ts-node-dev/node_modules/minimatch": { + "version": "3.1.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz", + "integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/ts-node-dev/node_modules/rimraf": { + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", + "integrity": "sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==", + "deprecated": "Rimraf versions prior to v4 are no longer supported", + "dev": true, + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + } + }, + "node_modules/ts-node-dev/node_modules/ts-node": { + "version": "10.9.2", + "resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.2.tgz", + "integrity": 
"sha512-f0FFpIdcHgn8zcPSbf1dRevwt047YMnaiJM3u2w2RewrB+fob/zePZcrOyQoLMMO7aBIddLcQIEK5dYjkLnGrQ==", + "dev": true, + "dependencies": { + "@cspotcode/source-map-support": "^0.8.0", + "@tsconfig/node10": "^1.0.7", + "@tsconfig/node12": "^1.0.7", + "@tsconfig/node14": "^1.0.0", + "@tsconfig/node16": "^1.0.2", + "acorn": "^8.4.1", + "acorn-walk": "^8.1.1", + "arg": "^4.1.0", + "create-require": "^1.1.0", + "diff": "^4.0.1", + "make-error": "^1.1.1", + "v8-compile-cache-lib": "^3.0.1", + "yn": "3.1.1" + }, + "bin": { + "ts-node": "dist/bin.js", + "ts-node-cwd": "dist/bin-cwd.js", + "ts-node-esm": "dist/bin-esm.js", + "ts-node-script": "dist/bin-script.js", + "ts-node-transpile-only": "dist/bin-transpile.js", + "ts-script": "dist/bin-script-deprecated.js" + }, + "peerDependencies": { + "@swc/core": ">=1.2.50", + "@swc/wasm": ">=1.2.50", + "@types/node": "*", + "typescript": ">=2.7" + }, + "peerDependenciesMeta": { + "@swc/core": { + "optional": true + }, + "@swc/wasm": { + "optional": true + } + } + }, + "node_modules/tsconfig": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/tsconfig/-/tsconfig-7.0.0.tgz", + "integrity": "sha512-vZXmzPrL+EmC4T/4rVlT2jNVMWCi/O4DIiSj3UHg1OE5kCKbk4mfrXc6dZksLgRM/TZlKnousKH9bbTazUWRRw==", + "dev": true, + "dependencies": { + "@types/strip-bom": "^3.0.0", + "@types/strip-json-comments": "0.0.30", + "strip-bom": "^3.0.0", + "strip-json-comments": "^2.0.0" + } + }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==" + }, + "node_modules/tunnel-agent": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz", + "integrity": "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==", + "license": "Apache-2.0", + "dependencies": { + "safe-buffer": "^5.0.1" + }, + 
"engines": { + "node": "*" + } + }, + "node_modules/type-fest": { + "version": "5.5.0", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-5.5.0.tgz", + "integrity": "sha512-PlBfpQwiUvGViBNX84Yxwjsdhd1TUlXr6zjX7eoirtCPIr08NAmxwa+fcYBTeRQxHo9YC9wwF3m9i700sHma8g==", + "license": "(MIT OR CC0-1.0)", + "optional": true, + "dependencies": { + "tagged-tag": "^1.0.0" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/typescript": { + "version": "5.7.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.7.3.tgz", + "integrity": "sha512-84MVSjMEHP+FQRPy3pX9sTVV/INIex71s9TL2Gm5FG/WG1SqXeKyZ0k7/blY/4FdOzI12CBy1vGc4og/eus0fw==", + "devOptional": true, + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/undici-types": { + "version": "6.19.8", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.19.8.tgz", + "integrity": "sha512-ve2KP6f/JnbPBFyobGHuerC9g1FYGn/F8n1LWTwNxCEzd6IfqTwUQcNXgEtmmQ6DlRrC1hrSrBnCZPokRrDHjw==", + "dev": true + }, + "node_modules/unique-filename": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/unique-filename/-/unique-filename-1.1.1.tgz", + "integrity": "sha512-Vmp0jIp2ln35UTXuryvjzkjGdRyf9b2lTXuSYUiPmzRcl3FDtYqAwOnTJkAngD9SWhnoJzDbTKwaOrZ+STtxNQ==", + "license": "ISC", + "optional": true, + "dependencies": { + "unique-slug": "^2.0.0" + } + }, + "node_modules/unique-slug": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/unique-slug/-/unique-slug-2.0.2.tgz", + "integrity": "sha512-zoWr9ObaxALD3DOPfjPSqxt4fnZiWblxHIgeWqW8x7UqDzEtHEQLzji2cuJYQFCU6KmoJikOYAZlrTHHebjx2w==", + "license": "ISC", + "optional": true, + "dependencies": { + "imurmurhash": "^0.1.4" + } + }, + "node_modules/uplot": { + "version": "1.6.32", + "resolved": "https://registry.npmjs.org/uplot/-/uplot-1.6.32.tgz", + "integrity": 
"sha512-KIMVnG68zvu5XXUbC4LQEPnhwOxBuLyW1AHtpm6IKTXImkbLgkMy+jabjLgSLMasNuGGzQm/ep3tOkyTxpiQIw==", + "license": "MIT" + }, + "node_modules/use-isomorphic-layout-effect": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/use-isomorphic-layout-effect/-/use-isomorphic-layout-effect-1.2.0.tgz", + "integrity": "sha512-q6ayo8DWoPZT0VdG4u3D3uxcgONP3Mevx2i2b0434cwWBoL+aelL1DzkXI6w3PhTZzUeR2kaVlZn70iCiseP6w==", + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==" + }, + "node_modules/uuid": { + "version": "11.1.1", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-11.1.1.tgz", + "integrity": "sha512-vIYxrBCC/N/K+Js3qSN88go7kIfNPssr/hHCesKCQNAjmgvYS2oqr69kIufEG+O4+PfezOH4EbIeHCfFov8ZgQ==", + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist/esm/bin/uuid" + } + }, + "node_modules/v8-compile-cache-lib": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz", + "integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==", + "dev": true + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/wide-align": { + "version": "1.1.5", + "resolved": 
"https://registry.npmjs.org/wide-align/-/wide-align-1.1.5.tgz", + "integrity": "sha512-eDMORYaPNZ4sQIuuYPDHdQvf4gyCF9rEEV/yPxGfwPkRodwEgiMUUXTx/dex+Me0wxx53S+NgUHaP7y3MGlDmg==", + "license": "ISC", + "optional": true, + "dependencies": { + "string-width": "^1.0.2 || 2 || 3 || 4" + } + }, + "node_modules/wide-align/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/wide-align/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT", + "optional": true + }, + "node_modules/wide-align/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "optional": true, + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wide-align/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "optional": true, + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/widest-line": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/widest-line/-/widest-line-6.0.0.tgz", + "integrity": 
"sha512-U89AsyEeAsyoF0zVJBkG9zBgekjgjK7yk9sje3F4IQpXBJ10TF6ByLlIfjMhcmHMJgHZI4KHt4rdNfktzxIAMA==", + "license": "MIT", + "optional": true, + "dependencies": { + "string-width": "^8.1.0" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/widest-line/node_modules/string-width": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-8.2.0.tgz", + "integrity": "sha512-6hJPQ8N0V0P3SNmP6h2J99RLuzrWz2gvT7VnK5tKvrNqJoyS9W4/Fb8mo31UiPvy00z7DQXkP2hnKBVav76thw==", + "license": "MIT", + "optional": true, + "dependencies": { + "get-east-asian-width": "^1.5.0", + "strip-ansi": "^7.1.2" + }, + "engines": { + "node": ">=20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/wrap-ansi": { + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz", + "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==", + "dependencies": { + "ansi-styles": "^6.1.0", + "string-width": "^5.0.1", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs": { + "name": "wrap-ansi", + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": 
"sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==" + }, + "node_modules/wrap-ansi-cjs/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "license": "ISC" + }, + "node_modules/ws": { + "version": "8.21.0", + "resolved": 
"https://registry.npmjs.org/ws/-/ws-8.21.0.tgz", + "integrity": "sha512-Vsp28b7DRcimFQvrqu2Wek3z1iYxDCWqHYB8Qsnk/S4RfaCQzPGPyBNuVjJV3cd6UiKtUtp6sNM77gWvzcCH+g==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/xtend": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", + "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==", + "dev": true, + "engines": { + "node": ">=0.4" + } + }, + "node_modules/y18n": { + "version": "5.0.8", + "resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz", + "integrity": "sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA==", + "dev": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/yallist": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", + "license": "ISC" + }, + "node_modules/yaml": { + "version": "2.9.0", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.9.0.tgz", + "integrity": "sha512-2AvhNX3mb8zd6Zy7INTtSpl1F15HW6Wnqj0srWlkKLcpYl/gMIMJiyuGq2KeI2YFxUPjdlB+3Lc10seMLtL4cA==", + "license": "ISC", + "bin": { + "yaml": "bin.mjs" + }, + "engines": { + "node": ">= 14.6" + }, + "funding": { + "url": "https://github.com/sponsors/eemeli" + } + }, + "node_modules/yargs": { + "version": "17.7.2", + "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz", + "integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==", + "dev": true, + "dependencies": { + "cliui": "^8.0.1", + "escalade": "^3.1.1", + 
"get-caller-file": "^2.0.5", + "require-directory": "^2.1.1", + "string-width": "^4.2.3", + "y18n": "^5.0.5", + "yargs-parser": "^21.1.1" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/yargs-parser": { + "version": "21.1.1", + "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", + "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", + "dev": true, + "engines": { + "node": ">=12" + } + }, + "node_modules/yargs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true + }, + "node_modules/yargs/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yn": { + "version": "3.1.1", + "resolved": 
"https://registry.npmjs.org/yn/-/yn-3.1.1.tgz", + "integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/yoga-layout": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/yoga-layout/-/yoga-layout-3.2.1.tgz", + "integrity": "sha512-0LPOt3AxKqMdFBZA3HBAt/t/8vIKq7VaQYbuA8WxCgung+p9TVyKRYdpvCb80HcdTN2NkbIKbhNwKUfm3tQywQ==", + "license": "MIT", + "optional": true + }, + "node_modules/zip-stream": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/zip-stream/-/zip-stream-6.0.1.tgz", + "integrity": "sha512-zK7YHHz4ZXpW89AHXUPbQVGKI7uvkd3hzusTdotCg1UxyaVtg0zFJSTfW/Dq5f7OBBVnq6cZIaC8Ti4hb6dtCA==", + "dependencies": { + "archiver-utils": "^5.0.0", + "compress-commons": "^6.0.2", + "readable-stream": "^4.0.0" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/zip-stream/node_modules/buffer": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/zip-stream/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + "dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + } + } +} 
diff --git a/ai-toolkit/ui/package.json b/ai-toolkit/ui/package.json new file mode 100644 index 0000000000000000000000000000000000000000..86d990a4c229db0603ca1759cd00502b96e6e8e2 --- /dev/null +++ b/ai-toolkit/ui/package.json @@ -0,0 +1,56 @@ +{ + "name": "ai-toolkit-ui", + "version": "0.1.0", + "private": true, + "scripts": { + "dev": "concurrently -k -n WORKER,UI \"ts-node-dev --project tsconfig.worker.json --respawn --watch cron --transpile-only cron/worker.ts\" \"next dev --turbopack\"", + "build": "tsc -p tsconfig.worker.json && next build", + "start": "concurrently --restart-tries -1 --restart-after 1000 -n WORKER,UI \"node dist/cron/worker.js\" \"next start --port 8675\"", + "build_and_start": "npm install && npm run update_db && npm run build && npm run start", + "lint": "next lint", + "update_db": "npx prisma generate && npx prisma db push", + "format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\"" + }, + "dependencies": { + "@headlessui/react": "^2.2.0", + "@monaco-editor/react": "^4.7.0", + "@prisma/client": "^6.3.1", + "archiver": "^7.0.1", + "axios": "^1.7.9", + "classnames": "^2.5.1", + "lucide-react": "^0.475.0", + "next": "^15.5.9", + "node-cache": "^5.1.2", + "prisma": "^6.3.1", + "react": "^19.2.0", + "react-dom": "^19.2.0", + "react-dropzone": "^14.3.5", + "react-global-hooks": "^1.3.5", + "react-icons": "^5.5.0", + "react-select": "^5.10.1", + "react-virtuoso": "^4.18.7", + "react-zoom-pan-pinch": "^4.0.3", + "sqlite3": "^5.1.7", + "systeminformation": "^5.27.11", + "uplot": "^1.6.32", + "uuid": "^11.1.0", + "yaml": "^2.7.0" + }, + "devDependencies": { + "@types/archiver": "^6.0.3", + "@types/node": "^20", + "@types/react": "^19", + "@types/react-dom": "^19", + "concurrently": "^9.1.2", + "postcss": "^8", + "prettier": "^3.5.1", + "prettier-basic": "^1.0.0", + "tailwindcss": "^3.4.1", + "ts-node-dev": "^2.0.0", + "typescript": "^5" + }, + "optionalDependencies": { + "macstats": "^4.2.0" + }, + "prettier": "prettier-basic" +} diff --git 
a/ai-toolkit/ui/postcss.config.mjs b/ai-toolkit/ui/postcss.config.mjs new file mode 100644 index 0000000000000000000000000000000000000000..1a69fd2a450afc3bf47e08b22c149190df0ffdb4 --- /dev/null +++ b/ai-toolkit/ui/postcss.config.mjs @@ -0,0 +1,8 @@ +/** @type {import('postcss-load-config').Config} */ +const config = { + plugins: { + tailwindcss: {}, + }, +}; + +export default config; diff --git a/ai-toolkit/ui/prisma/schema.prisma b/ai-toolkit/ui/prisma/schema.prisma new file mode 100644 index 0000000000000000000000000000000000000000..79c90fbf1b92dadb9945cef7bf022e8f0d8b8f9b --- /dev/null +++ b/ai-toolkit/ui/prisma/schema.prisma @@ -0,0 +1,49 @@ +generator client { + provider = "prisma-client-js" +} + +datasource db { + provider = "sqlite" + url = "file:../../aitk_db.db" +} + +model Settings { + id Int @id @default(autoincrement()) + key String @unique + value String +} + +model Queue { + id Int @id @default(autoincrement()) + gpu_ids String @unique + is_running Boolean @default(false) + + @@index([gpu_ids]) +} + +model Job { + id String @id @default(uuid()) + name String @unique + gpu_ids String + job_config String // JSON string + created_at DateTime @default(now()) + updated_at DateTime @updatedAt + status String @default("stopped") + stop Boolean @default(false) + return_to_queue Boolean @default(false) // same as stop, but will be set to 'queued' when stopped + step Int @default(0) + total_steps Int? + info String @default("") + speed_string String @default("") + queue_position Int @default(0) + pid Int? + job_type String @default("train") // 'train', 'caption' + job_ref String? 
// can be used for anything for special jobs, like dataset path for caption jobs + save_now Boolean @default(false) // if true, the job will be saved on the next step + + @@index([status]) + @@index([gpu_ids]) + @@index([job_type]) + @@index([job_ref]) + +} diff --git a/ai-toolkit/ui/public/file.svg b/ai-toolkit/ui/public/file.svg new file mode 100644 index 0000000000000000000000000000000000000000..004145cddf3f9db91b57b9cb596683c8eb420862 --- /dev/null +++ b/ai-toolkit/ui/public/file.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/ai-toolkit/ui/public/globe.svg b/ai-toolkit/ui/public/globe.svg new file mode 100644 index 0000000000000000000000000000000000000000..567f17b0d7c7fb662c16d4357dd74830caf2dccb --- /dev/null +++ b/ai-toolkit/ui/public/globe.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/ai-toolkit/ui/public/next.svg b/ai-toolkit/ui/public/next.svg new file mode 100644 index 0000000000000000000000000000000000000000..5174b28c565c285e3e312ec5178be64fbeca8398 --- /dev/null +++ b/ai-toolkit/ui/public/next.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/ai-toolkit/ui/public/vercel.svg b/ai-toolkit/ui/public/vercel.svg new file mode 100644 index 0000000000000000000000000000000000000000..77053960334e2e34dc584dea8019925c3b4ccca9 --- /dev/null +++ b/ai-toolkit/ui/public/vercel.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/ai-toolkit/ui/public/window.svg b/ai-toolkit/ui/public/window.svg new file mode 100644 index 0000000000000000000000000000000000000000..b2b2a44f6ebc70c450043c05a002e7a93ba5d651 --- /dev/null +++ b/ai-toolkit/ui/public/window.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/ai-toolkit/ui/src/app/api/audio/art/[...audioPath]/route.ts b/ai-toolkit/ui/src/app/api/audio/art/[...audioPath]/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..fbc650471c96a2686bf030344d9c40bbd4719527 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/audio/art/[...audioPath]/route.ts @@ 
/** Decodes a 28-bit ID3 "synchsafe" integer: four bytes carrying 7 significant bits each. */
function synchsafeToInt(b0: number, b1: number, b2: number, b3: number) {
  return ((b0 & 0x7f) << 21) | ((b1 & 0x7f) << 14) | ((b2 & 0x7f) << 7) | (b3 & 0x7f);
}

/**
 * Reverses ID3v2 unsynchronisation: every `0xFF 0x00` pair collapses back to a single `0xFF`.
 * Writes into a pre-sized Buffer instead of growing a `number[]` one byte at a time.
 */
function deUnsync(bytes: Buffer) {
  const out = Buffer.alloc(bytes.length);
  let written = 0;
  for (let i = 0; i < bytes.length; i++) {
    out[written++] = bytes[i];
    // Drop the stuffing byte that follows 0xFF.
    if (bytes[i] === 0xff && i + 1 < bytes.length && bytes[i + 1] === 0x00) i += 1;
  }
  return out.subarray(0, written);
}

/**
 * Reads a NUL-terminated string starting at `start`.
 *
 * @param buf   Buffer to read from.
 * @param start Offset of the first character.
 * @param wide  When true the string is 2 bytes per unit (terminated by `00 00`) and decoded as
 *              UTF-16LE; otherwise it is 1 byte per char (terminated by `00`) and decoded as latin1.
 * @returns The decoded text and the offset just past the terminator (which may exceed
 *          `buf.length` when no terminator was found).
 */
function readNullTerminated(buf: Buffer, start: number, wide: boolean): { text: string; next: number } {
  if (wide) {
    let i = start;
    while (i + 1 < buf.length && !(buf[i] === 0 && buf[i + 1] === 0)) i += 2;
    return { text: buf.subarray(start, i).toString('utf16le'), next: i + 2 };
  }
  let i = start;
  while (i < buf.length && buf[i] !== 0) i++;
  return { text: buf.subarray(start, i).toString('latin1'), next: i + 1 };
}

type ArtResult = { mime: string; data: Buffer } | null;

/**
 * Parses the payload of one ID3v2.3/v2.4 `APIC` frame.
 *
 * @param raw         Frame payload exactly as stored (frame header already removed).
 * @param verMajor    ID3v2 major version (3 or 4).
 * @param formatFlags Second frame-flag byte (the "format" flags).
 * @param tagUnsync   Whether the tag header's unsynchronisation bit is set.
 * @returns The picture, or null when the frame holds no usable embedded image.
 */
function parseApicFrame(raw: Buffer, verMajor: number, formatFlags: number, tagUnsync: boolean): ArtResult {
  let frame = raw;
  if (verMajor === 4) {
    // Compressed (0x08) or encrypted (0x04) payloads cannot be parsed as-is.
    if ((formatFlags & 0x0c) !== 0) return null;
    // v2.4 unsynchronises per frame; the tag-level bit only says "all frames are unsynchronised".
    if (tagUnsync || (formatFlags & 0x02) !== 0) frame = deUnsync(frame);
    let skip = 0;
    if ((formatFlags & 0x40) !== 0) skip += 1; // grouping identity byte
    if ((formatFlags & 0x01) !== 0) skip += 4; // data length indicator
    frame = frame.subarray(skip);
  }
  if (frame.length < 2) return null;

  const enc = frame[0];
  // MIME type: NUL-terminated latin1.
  const mimeZ = readNullTerminated(frame, 1, false);
  const rawMime = mimeZ.text.trim();
  // "-->" means the frame carries a URL to the picture, not picture bytes.
  if (rawMime === '-->') return null;
  // The spec implies the "image/" prefix when the tagger omitted it (e.g. "PNG").
  const mime = rawMime === '' ? 'image/jpeg' : rawMime.includes('/') ? rawMime : `image/${rawMime.toLowerCase()}`;

  let p = mimeZ.next;
  if (p < frame.length) p += 1; // picture type byte
  const wide = enc === 1 || enc === 2;
  p = readNullTerminated(frame, p, wide).next; // description
  if (p >= frame.length) return null;

  const img = frame.subarray(p);
  // Anything this small cannot be a real image; treat as absent.
  return img.length > 64 ? { mime, data: Buffer.from(img) } : null;
}

/**
 * Extracts the first embedded picture from an ID3v2 tag.
 *
 * Supports `PIC` frames (ID3v2.2) and `APIC` frames (ID3v2.3 / v2.4).
 *
 * @param buf Bytes starting at the "ID3" tag header; may be truncated, in which case only
 *            frames that lie fully inside `buf` are considered.
 * @returns The picture's MIME type and raw bytes, or null when there is no tag or no picture.
 */
function extractArtFromTag(buf: Buffer): ArtResult {
  if (buf.length < 10) return null;
  if (buf[0] !== 0x49 || buf[1] !== 0x44 || buf[2] !== 0x33) return null; // "ID3"

  const verMajor = buf[3]; // 2, 3, or 4
  const flags = buf[5];
  const tagUnsync = (flags & 0x80) !== 0;
  const tagSize = synchsafeToInt(buf[6], buf[7], buf[8], buf[9]);
  const tagEnd = Math.min(10 + tagSize, buf.length);

  let tagData = buf.subarray(10, tagEnd);
  // Up to v2.3 the whole tag body is unsynchronised. In v2.4 it is done per frame and frame
  // sizes count the stored (unsynchronised) bytes, so the body must be walked untouched.
  if (tagUnsync && verMajor !== 4) tagData = deUnsync(tagData);

  let offset = 0;

  // Skip extended header
  if ((verMajor === 3 || verMajor === 4) && (flags & 0x40) !== 0 && tagData.length >= 4) {
    if (verMajor === 4) {
      // v2.4: synchsafe size of the WHOLE extended header (never less than 6 bytes).
      offset += Math.max(6, synchsafeToInt(tagData[0], tagData[1], tagData[2], tagData[3]));
    } else {
      // v2.3: plain 32-bit size that EXCLUDES the 4 size bytes themselves.
      const extSize = (tagData[0] << 24) | (tagData[1] << 16) | (tagData[2] << 8) | tagData[3];
      offset += 4 + Math.max(0, extSize);
    }
  }

  while (offset < tagData.length) {
    if (tagData[offset] === 0x00) break; // reached padding

    if (verMajor === 2) {
      // ID3v2.2: 3-byte frame ID, 3-byte size
      if (offset + 6 > tagData.length) break;
      const id = tagData.subarray(offset, offset + 3).toString('latin1');
      const size = (tagData[offset + 3] << 16) | (tagData[offset + 4] << 8) | tagData[offset + 5];
      offset += 6;
      if (!id.trim() || size <= 0 || offset + size > tagData.length) break;

      if (id === 'PIC' && size > 6) {
        const frame = tagData.subarray(offset, offset + size);
        const fmt = frame.subarray(1, 4).toString('latin1').toLowerCase();
        const mime = fmt === 'png' ? 'image/png' : 'image/jpeg';
        // encoding(1) + format(3) + pictureType(1) = 5, then a NUL-terminated description
        const wide = frame[0] === 1 || frame[0] === 2;
        const desc = readNullTerminated(frame, 5, wide);
        if (desc.next < frame.length) {
          const img = frame.subarray(desc.next);
          if (img.length > 64) return { mime, data: Buffer.from(img) };
        }
      }
      offset += size;
    } else {
      // ID3v2.3/v2.4: 4-byte frame ID, 4-byte size, 2-byte flags
      if (offset + 10 > tagData.length) break;
      const id = tagData.subarray(offset, offset + 4).toString('latin1');
      const size =
        verMajor === 4
          ? synchsafeToInt(tagData[offset + 4], tagData[offset + 5], tagData[offset + 6], tagData[offset + 7])
          : (tagData[offset + 4] << 24) |
            (tagData[offset + 5] << 16) |
            (tagData[offset + 6] << 8) |
            tagData[offset + 7];
      const formatFlags = tagData[offset + 9];
      offset += 10;
      // A negative size (high bit set in a v2.3 size) is also rejected here.
      if (!id.trim() || size <= 0 || offset + size > tagData.length) break;

      if (id === 'APIC') {
        const art = parseApicFrame(tagData.subarray(offset, offset + size), verMajor, formatFlags, tagUnsync);
        if (art) return art;
      }
      offset += size;
    }
  }
  return null;
}
/**
 * Auth probe endpoint.
 *
 * Per the original author's note, a request that reaches this handler has already had its
 * auth verified, so the handler simply reports success.
 */
export async function GET() {
  const payload = { isAuthenticated: true };
  return NextResponse.json(payload);
}
/**
 * Returns true when `filepath`, once normalised, is `root` itself or lies beneath it.
 *
 * Both sides are resolved first, so `..` segments, trailing separators and a relative root
 * cannot defeat the comparison. Filenames that merely contain `..` as text are still accepted.
 */
function isUnderRoot(filepath: string, root: string): boolean {
  const rel = path.relative(path.resolve(root), path.resolve(filepath));
  if (rel === '') return true;
  return rel !== '..' && !rel.startsWith('..' + path.sep) && !path.isAbsolute(rel);
}

/**
 * Batch caption lookup.
 *
 * Request body: `{ imgPaths: string[], ext?: string }`. For every image path inside the
 * datasets root, reads the sibling caption file (same basename, extension `ext`, default
 * `txt`). Responds with `{ captions: { [imgPath]: text } }`; missing or unreadable caption
 * files yield an empty string, and paths outside the datasets root are silently omitted.
 *
 * Responds 499 when the client aborted before or while sending the body, and 400 when
 * `imgPaths` is not an array.
 */
export async function POST(request: NextRequest) {
  let body: unknown;
  try {
    body = await request.json();
  } catch {
    // Client aborted the request before the body was fully sent
    return new NextResponse(null, { status: 499 });
  }

  if (request.signal.aborted) {
    return new NextResponse(null, { status: 499 });
  }

  // `body` may legally be `null` or a primitive; fall back to an empty object so the
  // destructuring cannot throw.
  const { imgPaths, ext } = (typeof body === 'object' && body !== null ? body : {}) as {
    imgPaths?: unknown;
    ext?: unknown;
  };
  if (!Array.isArray(imgPaths)) {
    return NextResponse.json({ error: 'imgPaths must be an array' }, { status: 400 });
  }

  const captionExt = (typeof ext === 'string' ? ext : '').replace(/^\.+/, '').trim() || 'txt';
  const allowedDir = await getDatasetsRoot();
  const captions: Record<string, string> = {};

  for (const imgPath of imgPaths) {
    if (typeof imgPath !== 'string') continue;

    const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.' + captionExt;
    // The caption path is what actually gets read, and `ext` is caller-controlled, so it
    // must pass the containment check too, not just the image path.
    if (!isUnderRoot(imgPath, allowedDir) || !isUnderRoot(captionPath, allowedDir)) continue;

    try {
      captions[imgPath] = fs.existsSync(captionPath) ? fs.readFileSync(captionPath, 'utf-8') : '';
    } catch {
      // Best effort: an unreadable caption is reported as empty rather than failing the batch.
      captions[imgPath] = '';
    }
  }

  return NextResponse.json({ captions });
}
/**
 * Deletes a dataset folder (recursively) from the datasets root.
 *
 * Request body: `{ name: string }`. Responds `{ success: true }` when the folder was removed
 * or did not exist, 400 when `name` is missing or resolves to the datasets root itself or to
 * a location outside it, and 500 on any other failure.
 */
export async function POST(request: Request) {
  try {
    const body = await request.json();
    const name: unknown = body?.name;
    if (typeof name !== 'string' || name.trim() === '') {
      return NextResponse.json({ error: 'Invalid dataset name' }, { status: 400 });
    }

    const datasetsPath = path.resolve(await getDatasetsRoot());
    const datasetPath = path.resolve(path.join(datasetsPath, name));

    // `name` is caller-controlled and feeds a recursive delete: refuse anything that
    // resolves to the datasets root itself or escapes it via `..` segments.
    const rel = path.relative(datasetsPath, datasetPath);
    if (rel === '' || rel === '..' || rel.startsWith('..' + path.sep) || path.isAbsolute(rel)) {
      return NextResponse.json({ error: 'Invalid dataset name' }, { status: 400 });
    }

    // if folder doesnt exist, ignore
    if (!fs.existsSync(datasetPath)) {
      return NextResponse.json({ success: true });
    }

    // delete it and return success
    fs.rmSync(datasetPath, { recursive: true, force: true });
    return NextResponse.json({ success: true });
  } catch (error) {
    console.error('Error deleting dataset:', error);
    return NextResponse.json({ error: 'Failed to delete dataset' }, { status: 500 });
  }
}
a/ai-toolkit/ui/src/app/api/datasets/listImages/route.ts b/ai-toolkit/ui/src/app/api/datasets/listImages/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..92e666a8adbddfdb26ad01700c26d62d91a3b883
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/datasets/listImages/route.ts
@@ -0,0 +1,87 @@
+import { NextResponse } from 'next/server';
+import fs from 'fs';
+import path from 'path';
+import { getDatasetsRoot } from '@/server/settings';
+
+// Extensions (lower case) collected by findImagesRecursively: images, video, audio.
+const MEDIA_EXTENSIONS = new Set([
+  '.png', '.jpg', '.jpeg', '.webp',
+  '.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv',
+  '.mp3', '.wav', '.flac', '.ogg',
+]);
+
+/**
+ * POST /api/datasets/listImages
+ *
+ * Body: `{ datasetName }`. Responds with `{ images: [{ img_path }] }`, one entry
+ * per media file found anywhere under that dataset folder, sorted by path.
+ */
+export async function POST(request: Request) {
+  try {
+    const body = await request.json();
+    const { datasetName } = body;
+    if (typeof datasetName !== 'string' || datasetName.trim() === '') {
+      return NextResponse.json({ error: 'datasetName is required' }, { status: 400 });
+    }
+
+    // datasetName is client input: resolve it and refuse anything that escapes
+    // the datasets root (e.g. via `..`) before touching the filesystem
+    const datasetsPath = path.resolve(await getDatasetsRoot());
+    const datasetFolder = path.resolve(datasetsPath, datasetName);
+    if (!datasetFolder.startsWith(datasetsPath + path.sep)) {
+      return NextResponse.json({ error: 'Invalid dataset name' }, { status: 400 });
+    }
+
+    // Check if folder exists
+    if (!fs.existsSync(datasetFolder)) {
+      return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 });
+    }
+
+    // Find all images recursively
+    const imageFiles = findImagesRecursively(datasetFolder);
+
+    // Sort server-side so the client doesn't have to sort large lists
+    imageFiles.sort((a, b) => a.localeCompare(b));
+
+    // Format response
+    const result = imageFiles.map(imgPath => ({
+      img_path: imgPath,
+    }));
+
+    return NextResponse.json({ images: result });
+  } catch (error) {
+    console.error('Error finding images:', error);
+    return NextResponse.json({ error: 'Failed to process request' }, { status: 500 });
+  }
+}
+
+/**
+ * Recursively finds all media files in a directory and its subdirectories.
+ * Dot-prefixed entries and `_controls` directories are skipped.
+ * @param dir Directory to search
+ * @returns Paths of the matching files, each built by joining onto `dir`
+ */
+function findImagesRecursively(dir: string): string[] {
+  const results: string[] = [];
+
+  // withFileTypes avoids a separate statSync per entry — a big win on large datasets
+  const entries = fs.readdirSync(dir, { withFileTypes: true });
+
+  for (const entry of entries) {
+    const name = entry.name;
+    if (name.startsWith('.')) continue;
+    const itemPath = path.join(dir, name);
+
+    if (entry.isDirectory()) {
+      if (name === '_controls') continue;
+      // append in place; re-allocating with concat() for every directory is wasteful
+      for (const nested of findImagesRecursively(itemPath)) {
+        results.push(nested);
+      }
+    } else if (entry.isFile() && MEDIA_EXTENSIONS.has(path.extname(name).toLowerCase())) {
+      results.push(itemPath);
+    }
+  }
+
+  return results;
+}
diff --git a/ai-toolkit/ui/src/app/api/datasets/upload/route.ts b/ai-toolkit/ui/src/app/api/datasets/upload/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..51aff81fd3bf4b091f10a1df9f2da887910f4753
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/datasets/upload/route.ts
@@ -0,0 +1,76 @@
+// src/app/api/datasets/upload/route.ts
+import { NextRequest, NextResponse } from 'next/server';
+import { writeFile, mkdir } from 'fs/promises';
+import { join, resolve, sep } from 'path';
+import { getDatasetsRoot } from '@/server/settings';
+
+/**
+ * POST /api/datasets/upload
+ *
+ * Multipart form fields: `files` (one or more uploads) and `datasetName`.
+ * Writes each upload into the dataset folder under a cleaned file name and
+ * responds with the list of names written.
+ */
+export async function POST(request: NextRequest) {
+  try {
+    const datasetsPath = await getDatasetsRoot();
+    if (!datasetsPath) {
+      return NextResponse.json({ error: 'Datasets path not found' }, { status: 500 });
+    }
+    const formData = await request.formData();
+    const files = formData.getAll('files');
+    const datasetName = formData.get('datasetName');
+
+    if (!files || files.length === 0) {
+      return NextResponse.json({ error: 'No files provided' }, { status: 400 });
+    }
+    if (typeof datasetName !== 'string' || datasetName.trim() === '') {
+      return NextResponse.json({ error: 'datasetName is required' }, { status: 400 });
+    }
+
+    // datasetName is client input: the target must stay inside the datasets root
+    const datasetsRoot = resolve(datasetsPath);
+    const uploadDir = resolve(datasetsRoot, datasetName);
+    if (!uploadDir.startsWith(datasetsRoot + sep)) {
+      return NextResponse.json({ error: 'Invalid dataset name' }, { status: 400 });
+    }
+
+    // Create upload directory if it doesn't exist
+    await mkdir(uploadDir, { recursive: true });
+
+    const savedFiles: string[] = [];
+
+    // Process files sequentially to avoid overwhelming the system
+    for (const entry of files) {
+      // a plain string form value is not an upload
+      if (typeof entry === 'string') continue;
+      const buffer = Buffer.from(await entry.arrayBuffer());
+
+      // Clean filename. Names are not made unique: an upload whose cleaned
+      // name matches an existing file overwrites it.
+      const fileName = entry.name.replace(/[^a-zA-Z0-9.-]/g, '_');
+      const filePath = join(uploadDir, fileName);
+
+      await writeFile(filePath, buffer);
+      savedFiles.push(fileName);
+    }
+
+    return NextResponse.json({
+      message: 'Files uploaded successfully',
+      files: savedFiles,
+    });
+  } catch (error) {
+    console.error('Upload error:', error);
+    return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
+  }
+}
+
+// Increase payload size limit (default is 4mb)
+// NOTE(review): `config.api` is the Pages Router option format — confirm it has
+// any effect on this App Router route handler.
+export const config = {
+  api: {
+    bodyParser: false,
+    responseLimit: '50mb',
+  },
+};
diff --git a/ai-toolkit/ui/src/app/api/files/[...filePath]/route.ts b/ai-toolkit/ui/src/app/api/files/[...filePath]/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..3ecaf66f0ffdeb9ca8552a0d5bd560bed719d005
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/files/[...filePath]/route.ts
@@ -0,0 +1,166 @@
+/* eslint-disable */
+import { NextRequest, NextResponse } from 'next/server';
+import fs from 'fs';
+import path from 'path';
+import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
+
+// Extension → Content-Type; anything not listed is served as application/octet-stream
+const contentTypeMap: { [key: string]: string } = {
+  '.jpg': 'image/jpeg',
+  '.jpeg': 'image/jpeg',
+  '.png': 'image/png',
+  '.gif': 'image/gif',
+  '.webp': 'image/webp',
+  '.svg': 'image/svg+xml',
+  '.bmp': 'image/bmp',
+  '.safetensors': 'application/octet-stream',
+  '.zip': 'application/zip',
+  // Videos
+  '.mp4': 'video/mp4',
+  '.avi': 'video/x-msvideo',
+  '.mov': 'video/quicktime',
+  '.mkv': 'video/x-matroska',
+  '.wmv': 'video/x-ms-wmv',
+  '.m4v': 'video/x-m4v',
+  '.flv': 'video/x-flv',
+  // Audio
+  '.mp3': 'audio/mpeg',
+  '.wav': 'audio/wav',
+  '.flac': 'audio/flac',
+  '.ogg': 'audio/ogg',
+};
+
+// Open-ended ranges (`bytes=N-`) are answered in chunks of roughly this size
+const OPEN_RANGE_CHUNK = 10 * 1024 * 1024;
+
+type ByteRange = { start: number; end: number };
+
+/**
+ * Interprets a `Range` header for a file of `size` bytes.
+ * @returns inclusive byte offsets; 'ignore' when the header is not a single
+ *   `bytes=` range (serve the whole file); 'unsatisfiable' when the range lies
+ *   outside the file (answer 416).
+ */
+function parseRange(header: string, size: number): ByteRange | 'ignore' | 'unsatisfiable' {
+  const match = /^bytes=(\d*)-(\d*)$/.exec(header.trim());
+  if (!match || (match[1] === '' && match[2] === '')) return 'ignore';
+
+  if (match[1] === '') {
+    // suffix form `bytes=-N`: the last N bytes
+    const suffix = parseInt(match[2], 10);
+    if (suffix === 0 || size === 0) return 'unsatisfiable';
+    return { start: Math.max(0, size - suffix), end: size - 1 };
+  }
+
+  const start = parseInt(match[1], 10);
+  if (start >= size) return 'unsatisfiable';
+  // a requested end past EOF is clamped to the last byte
+  const requestedEnd = match[2] === '' ? start + OPEN_RANGE_CHUNK : parseInt(match[2], 10);
+  const end = Math.min(requestedEnd, size - 1);
+  return end < start ? 'ignore' : { start, end };
+}
+
+/**
+ * GET /api/files/[...filePath]
+ *
+ * Serves a file as a download (`Content-Disposition: attachment`) provided its
+ * resolved path lies under the datasets root or the training folder. Honours
+ * single-range `Range` requests with a 206 response.
+ */
+export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) {
+  const { filePath } = await params;
+  try {
+    // Decode the path
+    const decodedFilePath = decodeURIComponent(filePath);
+
+    // Get allowed directories
+    const datasetRoot = await getDatasetsRoot();
+    const trainingRoot = await getTrainingFolder();
+    const allowedDirs = [datasetRoot, trainingRoot];
+
+    // Security check: resolve so `..` segments collapse, then verify still under
+    // an allowed root. Substring `.includes('..')` false-positives on filenames
+    // containing `..` as text (e.g. an ellipsis in a filename).
+    const resolvedFilePath = path.resolve(decodedFilePath);
+    const isAllowed = allowedDirs.some(
+      allowedDir => resolvedFilePath === allowedDir || resolvedFilePath.startsWith(allowedDir + path.sep),
+    );
+
+    if (!isAllowed) {
+      console.warn(`Access denied: ${resolvedFilePath} not in ${allowedDirs.join(', ')}`);
+      return new NextResponse('Access denied', { status: 403 });
+    }
+
+    // Check if file exists
+    if (!fs.existsSync(resolvedFilePath)) {
+      console.warn(`File not found: ${resolvedFilePath}`);
+      return new NextResponse('File not found', { status: 404 });
+    }
+
+    // Get file info
+    const stat = fs.statSync(resolvedFilePath);
+    if (!stat.isFile()) {
+      return new NextResponse('Not a file', { status: 400 });
+    }
+
+    // Get filename for Content-Disposition
+    const filename = path.basename(resolvedFilePath);
+
+    // Determine content type
+    const ext = path.extname(resolvedFilePath).toLowerCase();
+    const contentType = contentTypeMap[ext] || 'application/octet-stream';
+
+    // Common headers for better download handling
+    const commonHeaders = {
+      'Content-Type': contentType,
+      'Accept-Ranges': 'bytes',
+      'Cache-Control': 'public, max-age=86400',
+      'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
+      'X-Content-Type-Options': 'nosniff',
+    };
+
+    // Range header → partial content, when it names a usable single range
+    const rangeHeader = request.headers.get('range');
+    const range = rangeHeader ? parseRange(rangeHeader, stat.size) : 'ignore';
+
+    if (range === 'unsatisfiable') {
+      return new NextResponse(null, {
+        status: 416,
+        headers: { 'Content-Range': `bytes */${stat.size}`, 'Accept-Ranges': 'bytes' },
+      });
+    }
+
+    if (range !== 'ignore') {
+      const { start, end } = range;
+      const fileStream = fs.createReadStream(resolvedFilePath, {
+        start,
+        end,
+        highWaterMark: 64 * 1024, // 64KB buffer
+      });
+
+      return new NextResponse(fileStream as any, {
+        status: 206,
+        headers: {
+          ...commonHeaders,
+          'Content-Range': `bytes ${start}-${end}/${stat.size}`,
+          'Content-Length': String(end - start + 1),
+        },
+      });
+    }
+
+    // Whole-file download
+    const fileStream = fs.createReadStream(resolvedFilePath, {
+      highWaterMark: 64 * 1024, // 64KB buffer
+    });
+
+    return new NextResponse(fileStream as any, {
+      headers: {
+        ...commonHeaders,
+        'Content-Length': String(stat.size),
+      },
+    });
+  } catch (error) {
+    console.error('Error serving file:', error);
+    return new NextResponse('Internal Server Error', { status: 500 });
+  }
+}
diff --git a/ai-toolkit/ui/src/app/api/files/delete/route.ts b/ai-toolkit/ui/src/app/api/files/delete/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..55a21efd95c030ab205d4a5aaa267cd2e12d1a41
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/files/delete/route.ts
@@ -0,0 +1,63 @@
+/* eslint-disable */
+import { NextRequest, NextResponse } from 'next/server';
+import fs from 'fs';
+import path from 'path';
+import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
+
+/**
+ * POST /api/files/delete
+ *
+ * Body: `{ filePath }`; the value is passed through decodeURIComponent. Deletes
+ * that single file provided its resolved path lies under the datasets root or
+ * the training folder.
+ */
+export async function POST(request: NextRequest) {
+  try {
+    const body = await request.json();
+    const { filePath } = body;
+
+    if (!filePath || typeof filePath !== 'string') {
+      return new NextResponse('filePath is required', { status: 400 });
+    }
+
+    // Decode the path
+    const decodedFilePath = decodeURIComponent(filePath);
+
+    // Get allowed directories
+    const datasetRoot = await getDatasetsRoot();
+    const trainingRoot = await getTrainingFolder();
+    const allowedDirs = [datasetRoot, trainingRoot];
+
+    // Security check: resolve so `..` segments collapse, then verify still under
+    // an allowed root. Substring `.includes('..')` false-positives on filenames
+    // containing `..` as text (e.g. an ellipsis in a filename).
+    const resolvedFilePath = path.resolve(decodedFilePath);
+    const isAllowed = allowedDirs.some(
+      allowedDir => resolvedFilePath === allowedDir || resolvedFilePath.startsWith(allowedDir + path.sep),
+    );
+
+    if (!isAllowed) {
+      console.warn(`Access denied: ${resolvedFilePath} not in ${allowedDirs.join(', ')}`);
+      return new NextResponse('Access denied', { status: 403 });
+    }
+
+    // Check if file exists
+    if (!fs.existsSync(resolvedFilePath)) {
+      console.warn(`File not found: ${resolvedFilePath}`);
+      return new NextResponse('File not found', { status: 404 });
+    }
+
+    // Get file info
+    const stat = fs.statSync(resolvedFilePath);
+    if (!stat.isFile()) {
+      return new NextResponse('Not a file', { status: 400 });
+    }
+
+    fs.unlinkSync(resolvedFilePath);
+
+    return NextResponse.json({ success: true });
+  } catch (error) {
+    console.error('Error deleting file:', error);
+    return new NextResponse('Internal Server Error', { status: 500 });
+  }
+}
diff --git a/ai-toolkit/ui/src/app/api/gpu/route.ts b/ai-toolkit/ui/src/app/api/gpu/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..53f06fd93a4d7497075b0e727fd3e9404ac86c5d
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/gpu/route.ts
@@ -0,0 +1,272 @@
+import { NextResponse } from 'next/server';
+import { exec, execSync } from 'child_process';
+import { promisify } from 'util';
+import { createRequire } from 'module';
+import os from 'os';
+
+const execAsync = promisify(exec);
+
+interface MacGpuResult {
+  name: string;
+  memUsed: number;
+  memTotal: number;
+  gpuLoad: number;
+  temperature: number;
+  fanSpeed: number;
+  powerDraw: number;
+}
+
+/**
+ * Collects GPU-related stats on macOS from `system_profiler` and, when it can be
+ * loaded, the optional `macstats` module. Each probe is best-effort: a failing
+ * probe leaves its field at the default instead of failing the whole call.
+ * Memory values are divided by 1024 * 1024 (bytes to MiB for `os.totalmem()`).
+ */
+async function getMacGpuInfo(): Promise<MacGpuResult | null> {
+  try {
+    const memoryTotal = os.totalmem() / (1024 * 1024);
+
+    // Get GPU name and core count from system_profiler
+    let gpuName = 'Apple GPU';
+    try {
+      const spOut = execSync(
+        'system_profiler SPDisplaysDataType 2>/dev/null | grep -E "Chipset Model|Total Number of Cores"',
+        { encoding: 'utf-8', timeout: 5000 },
+      );
+      const nameMatch = spOut.match(/Chipset Model:\s*(.+)/);
+      const coresMatch = spOut.match(/Total Number of Cores:\s*(\d+)/);
+      if (nameMatch) {
+        gpuName = nameMatch[1].trim();
+        if (coresMatch) {
+          gpuName += ` GPU (${coresMatch[1]} cores)`;
+        }
+      }
+    } catch {
+      // fallback to generic name
+    }
+
+    let temperature = 0;
+    let gpuLoad = 0;
+    let fanSpeed = 0;
+    let powerDraw = 0;
+    let memUsed = 0;
+    let memTotal = memoryTotal;
+
+    try {
+      // Use createRequire to hide from webpack static analysis so it doesn't fail on non-mac platforms
+      const nativeRequire = createRequire(import.meta.url);
+      const ms = nativeRequire('macstats') as any;
+
+      try {
+        const gpuData = ms.getGpuDataSync();
+        temperature = gpuData.temperature || 0;
+        gpuLoad = gpuData.usage || 0;
+      } catch {
+        // ignore
+      }
+
+      try {
+        const fanData = ms.getFanDataSync();
+        const fanKeys = Object.keys(fanData);
+        if (fanKeys.length > 0) {
+          fanSpeed = fanData[fanKeys[0]].rpm || 0;
+        }
+      } catch {
+        // ignore
+      }
+
+      try {
+        const powerData = ms.getPowerDataSync();
+        powerDraw = powerData.gpu || 0;
+      } catch {
+        // ignore
+      }
+
+      try {
+        const ramData = ms.getRAMUsageSync();
+        memUsed = ramData.used / (1024 * 1024);
+        memTotal = ramData.total / (1024 * 1024);
+      } catch {
+        // ignore
+      }
+    } catch (error) {
+      console.warn('macstats not available:', error);
+    }
+
+    return { name: gpuName, memUsed, memTotal, gpuLoad, temperature, fanSpeed, powerDraw };
+  } catch {
+    return null;
+  }
+}
+
+/**
+ * GET /api/gpu
+ *
+ * Reports GPU stats as `{ hasNvidiaSmi, isMac, gpus, error? }`: from
+ * getMacGpuInfo() on macOS, otherwise from `nvidia-smi` when it is available.
+ */
+export async function GET() {
+  try {
+    // Get platform
+    const platform = os.platform();
+    const isWindows = platform === 'win32';
+    const isMac = platform === 'darwin';
+
+    if (isMac) {
+      const macGpu = await getMacGpuInfo();
+      if (macGpu) {
+        return NextResponse.json({
+          hasNvidiaSmi: false,
+          isMac: true,
+          gpus: [
+            {
+              index: 0,
+              name: macGpu.name,
+              driverVersion: 'macOS',
+              temperature: Math.round(macGpu.temperature),
+              utilization: {
+                gpu: macGpu.gpuLoad,
+                memory: macGpu.memTotal > 0 ? Math.round((macGpu.memUsed / macGpu.memTotal) * 100) : 0,
+              },
+              memory: {
+                total: Math.round(macGpu.memTotal),
+                free: Math.round(macGpu.memTotal - macGpu.memUsed),
+                used: Math.round(macGpu.memUsed),
+              },
+              power: { draw: macGpu.powerDraw, limit: 0 },
+              clocks: { graphics: 0, memory: 0 },
+              fan: { speed: macGpu.fanSpeed },
+            },
+          ],
+        });
+      }
+      return NextResponse.json({
+        hasNvidiaSmi: false,
+        isMac: true,
+        gpus: [],
+        error: 'Could not read Mac GPU stats',
+      });
+    }
+
+    // Check if nvidia-smi is available
+    const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
+
+    if (!hasNvidiaSmi) {
+      return NextResponse.json({
+        hasNvidiaSmi: false,
+        isMac: false,
+        gpus: [],
+        error: 'nvidia-smi not found or not accessible',
+      });
+    }
+
+    // Get GPU stats
+    const gpuStats = await getGpuStats(isWindows);
+
+    return NextResponse.json({
+      hasNvidiaSmi: true,
+      isMac: false, // every other branch reports isMac; keep the shape uniform
+      gpus: gpuStats,
+    });
+  } catch (error) {
+    console.error('Error fetching NVIDIA GPU stats:', error);
+    return NextResponse.json(
+      {
+        hasNvidiaSmi: false,
+        isMac: false,
+        gpus: [],
+        error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`,
+      },
+      { status: 500 },
+    );
+  }
+}
+
+/** Resolves to true when the probe command for `nvidia-smi` succeeds, false otherwise. */
+async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
+  try {
+    if (isWindows) {
+      // Check if nvidia-smi is available on Windows
+      // It's typically located in C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe
+      // but we'll just try to run it directly as it may be in PATH
+      await execAsync('nvidia-smi -L');
+    } else {
+      // Linux/macOS check
+      await execAsync('which nvidia-smi');
+    }
+    return true;
+  } catch (error) {
+    return false;
+  }
+}
+
+/**
+ * Runs `nvidia-smi --query-gpu=...` and parses its CSV output into one stats
+ * object per GPU. Non-numeric fields parse to NaN, except the fan speed, which
+ * falls back to 0.
+ */
+async function getGpuStats(isWindows: boolean) {
+  // Command is the same for both platforms, but the path might be different
+  const command =
+    'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
+
+  // Execute command
+  const { stdout } = await execAsync(command, {
+    env: { ...process.env, CUDA_DEVICE_ORDER: 'PCI_BUS_ID' },
+  });
+
+  // Parse CSV output
+  const gpus = stdout
+    .trim()
+    .split('\n')
+    // empty output would otherwise yield one all-NaN pseudo GPU
+    .filter(line => line.trim() !== '')
+    .map(line => {
+      const [
+        index,
+        name,
+        driverVersion,
+        temperature,
+        gpuUtil,
+        memoryUtil,
+        memoryTotal,
+        memoryFree,
+        memoryUsed,
+        powerDraw,
+        powerLimit,
+        clockGraphics,
+        clockMemory,
+        fanSpeed,
+      ] = line.split(', ').map(item => item.trim());
+
+      return {
+        index: parseInt(index, 10),
+        name,
+        driverVersion,
+        temperature: parseInt(temperature, 10),
+        utilization: {
+          gpu: parseInt(gpuUtil, 10),
+          memory: parseInt(memoryUtil, 10),
+        },
+        memory: {
+          total: parseInt(memoryTotal, 10),
+          free: parseInt(memoryFree, 10),
+          used: parseInt(memoryUsed, 10),
+        },
+        power: {
+          draw: parseFloat(powerDraw),
+          limit: parseFloat(powerLimit),
+        },
+        clocks: {
+          graphics: parseInt(clockGraphics, 10),
+          memory: parseInt(clockMemory, 10),
+        },
+        fan: {
+          speed: parseInt(fanSpeed, 10) || 0, // Some GPUs might not report fan speed, default to 0
+        },
+      };
+    });
+
+  return gpus;
+}
+
diff 
--git a/ai-toolkit/ui/src/app/api/img/[...imagePath]/route.ts b/ai-toolkit/ui/src/app/api/img/[...imagePath]/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..e41362cf63251d15a689b141f75f404960d7d437
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/img/[...imagePath]/route.ts
@@ -0,0 +1,178 @@
+/* eslint-disable */
+import { NextRequest, NextResponse } from 'next/server';
+import fs from 'fs';
+import path from 'path';
+import { Readable } from 'stream';
+import { getDatasetsRoot, getTrainingFolder, getDataRoot } from '@/server/settings';
+
+const contentTypeMap: { [key: string]: string } = {
+  // Images
+  '.jpg': 'image/jpeg',
+  '.jpeg': 'image/jpeg',
+  '.png': 'image/png',
+  '.gif': 'image/gif',
+  '.webp': 'image/webp',
+  '.svg': 'image/svg+xml',
+  '.bmp': 'image/bmp',
+  // Videos
+  '.mp4': 'video/mp4',
+  '.avi': 'video/x-msvideo',
+  '.mov': 'video/quicktime',
+  '.mkv': 'video/x-matroska',
+  '.wmv': 'video/x-ms-wmv',
+  '.m4v': 'video/x-m4v',
+  '.flv': 'video/x-flv',
+  // Audio
+  '.mp3': 'audio/mpeg',
+  '.wav': 'audio/wav',
+  '.flac': 'audio/flac',
+  '.ogg': 'audio/ogg',
+};
+
+type ByteRange = { start: number; end: number };
+
+/**
+ * Interprets a `Range` header for a file of `size` bytes.
+ * @returns inclusive byte offsets; 'ignore' when the header is not a single
+ *   `bytes=` range (serve the whole file); 'unsatisfiable' when the range lies
+ *   outside the file (answer 416).
+ */
+function parseRange(header: string, size: number): ByteRange | 'ignore' | 'unsatisfiable' {
+  const match = /^bytes=(\d*)-(\d*)$/.exec(header.trim());
+  if (!match || (match[1] === '' && match[2] === '')) return 'ignore';
+
+  if (match[1] === '') {
+    // suffix form `bytes=-N`: the last N bytes
+    const suffix = parseInt(match[2], 10);
+    if (suffix === 0 || size === 0) return 'unsatisfiable';
+    return { start: Math.max(0, size - suffix), end: size - 1 };
+  }
+
+  const start = parseInt(match[1], 10);
+  if (start >= size) return 'unsatisfiable';
+  // an omitted end, or one past EOF, means "through the last byte"
+  const end = match[2] === '' ? size - 1 : Math.min(parseInt(match[2], 10), size - 1);
+  return end < start ? 'ignore' : { start, end };
+}
+
+/**
+ * GET /api/img/[...imagePath]
+ *
+ * Streams a file provided its resolved path lies under the datasets root, the
+ * training folder or the data root. Supports ETag revalidation (304) and
+ * single-range requests (206), and stops reading when the client disconnects.
+ */
+export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
+  const { imagePath } = await params;
+  try {
+    // Decode the path
+    const filepath = decodeURIComponent(imagePath);
+
+    // Get allowed directories
+    const datasetRoot = await getDatasetsRoot();
+    const trainingRoot = await getTrainingFolder();
+    const dataRoot = await getDataRoot();
+
+    const allowedDirs = [datasetRoot, trainingRoot, dataRoot];
+
+    // Security check: resolve the path so any `..` segments are collapsed,
+    // then ensure it's still under an allowed root. (Plain `.includes('..')`
+    // false-positives on filenames that contain `..` as text, e.g. an ellipsis.)
+    const resolved = path.resolve(filepath);
+    const isAllowed = allowedDirs.some(
+      allowedDir => resolved === allowedDir || resolved.startsWith(allowedDir + path.sep),
+    );
+
+    if (!isAllowed) {
+      console.warn(`Access denied: ${resolved} not in ${allowedDirs.join(', ')}`);
+      return new NextResponse('Access denied', { status: 403 });
+    }
+
+    // Bail out early if the client already gave up
+    if (request.signal.aborted) {
+      return new NextResponse(null, { status: 499 });
+    }
+
+    // Stat file (async)
+    const stat = await fs.promises.stat(resolved).catch(() => null);
+    if (!stat || !stat.isFile()) {
+      return new NextResponse('File not found', { status: 404 });
+    }
+
+    const ext = path.extname(resolved).toLowerCase();
+    const contentType = contentTypeMap[ext] || 'application/octet-stream';
+
+    // Weak ETag from inode/size/mtime — cheap and stable enough for revalidation
+    const etag = `W/"${stat.ino.toString(36)}-${stat.size.toString(36)}-${stat.mtimeMs.toString(36)}"`;
+    const cacheControl = 'public, max-age=86400, immutable';
+
+    const ifNoneMatch = request.headers.get('if-none-match');
+    if (ifNoneMatch && ifNoneMatch === etag) {
+      return new NextResponse(null, {
+        status: 304,
+        headers: {
+          ETag: etag,
+          'Cache-Control': cacheControl,
+        },
+      });
+    }
+
+    const buildBody = (start?: number, end?: number) => {
+      const nodeStream =
+        start !== undefined && end !== undefined
+          ? fs.createReadStream(resolved, { start, end })
+          : fs.createReadStream(resolved);
+
+      // Wire client disconnect → destroy the file stream so we don't keep
+      // reading bytes for a request the browser has already cancelled.
+      const onAbort = () => nodeStream.destroy();
+      if (request.signal.aborted) {
+        nodeStream.destroy();
+      } else {
+        request.signal.addEventListener('abort', onAbort, { once: true });
+      }
+      nodeStream.once('close', () => request.signal.removeEventListener('abort', onAbort));
+
+      return Readable.toWeb(nodeStream) as unknown as ReadableStream;
+    };
+
+    // Support range requests for video/audio seeking
+    const rangeHeader = request.headers.get('range');
+    const range = rangeHeader ? parseRange(rangeHeader, stat.size) : 'ignore';
+
+    if (range === 'unsatisfiable') {
+      return new NextResponse(null, {
+        status: 416,
+        headers: { 'Content-Range': `bytes */${stat.size}`, 'Accept-Ranges': 'bytes' },
+      });
+    }
+
+    if (range !== 'ignore') {
+      const { start, end } = range;
+      return new NextResponse(buildBody(start, end) as any, {
+        status: 206,
+        headers: {
+          'Content-Range': `bytes ${start}-${end}/${stat.size}`,
+          'Accept-Ranges': 'bytes',
+          'Content-Length': String(end - start + 1),
+          'Content-Type': contentType,
+          'Cache-Control': cacheControl,
+          ETag: etag,
+        },
+      });
+    }
+
+    return new NextResponse(buildBody() as any, {
+      headers: {
+        'Content-Type': contentType,
+        'Content-Length': String(stat.size),
+        'Cache-Control': cacheControl,
+        'Accept-Ranges': 'bytes',
+        ETag: etag,
+      },
+    });
+  } catch (error) {
+    console.error('Error serving image:', error);
+    return new NextResponse('Internal Server Error', { status: 500 });
+  }
+}
diff --git a/ai-toolkit/ui/src/app/api/img/caption/route.ts b/ai-toolkit/ui/src/app/api/img/caption/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..7aa82e9504243305225e5cc6f983dcabade14428
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/img/caption/route.ts
@@ -0,0 +1,54 @@
+import { NextResponse } from 'next/server';
+import fs from 'fs';
+import path from 'path';
+import { getDatasetsRoot } from '@/server/settings';
+
+/**
+ * POST /api/img/caption
+ *
+ * Body: `{ imgPath, caption, ext? }`. Writes `caption` next to the image, in a
+ * file with the image's base name and the extension `ext` (default `txt`).
+ * Only images inside the datasets root are accepted.
+ */
+export async function POST(request: Request) {
+  try {
+    const body = await request.json();
+    const { imgPath, caption, ext } = body;
+    if (typeof imgPath !== 'string') {
+      return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
+    }
+    if (typeof caption !== 'string') {
+      return NextResponse.json({ error: 'Invalid caption' }, { status: 400 });
+    }
+
+    // make sure the image really lives inside the datasets root: resolve first so
+    // `..` segments cannot slip past a plain prefix comparison
+    const datasetsPath = path.resolve(await getDatasetsRoot());
+    const resolvedImgPath = path.resolve(imgPath);
+    if (!resolvedImgPath.startsWith(datasetsPath + path.sep)) {
+      return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
+    }
+
+    // if img doesnt exist, report it
+    if (!fs.existsSync(resolvedImgPath)) {
+      return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
+    }
+
+    // check for caption (default extension txt)
+    const rawExt = typeof ext === 'string' ? ext : '';
+    const captionExt = rawExt.replace(/^\.+/, '').trim() || 'txt';
+    // the extension is client input too: letters and digits only, so it cannot
+    // smuggle path separators into captionPath
+    if (!/^[a-zA-Z0-9]+$/.test(captionExt)) {
+      return NextResponse.json({ error: 'Invalid caption extension' }, { status: 400 });
+    }
+    const captionPath = resolvedImgPath.replace(/\.[^/.]+$/, '') + '.' + captionExt;
+    // save caption to file
+    fs.writeFileSync(captionPath, caption);
+
+    return NextResponse.json({ success: true });
+  } catch (error) {
+    console.error('Error saving caption:', error);
+    return NextResponse.json({ error: 'Failed to save caption' }, { status: 500 });
+  }
+}
diff --git a/ai-toolkit/ui/src/app/api/img/delete/route.ts b/ai-toolkit/ui/src/app/api/img/delete/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..b6f56d69ca0d7facddc90370ee4b1f0994bd2d56
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/img/delete/route.ts
@@ -0,0 +1,56 @@
+import { NextResponse } from 'next/server';
+import fs from 'fs';
+import path from 'path';
+import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
+
+// Extensions this endpoint is willing to delete (images plus some video/audio)
+const DELETABLE_MEDIA = /\.(jpg|jpeg|png|bmp|gif|tiff|webp|mp4|mp3|wav|flac|ogg)$/i;
+
+/**
+ * POST /api/img/delete
+ *
+ * Body: `{ imgPath }`. Deletes the media file and, if present, the `.txt`
+ * caption sharing its base name. The file must lie inside the datasets root or
+ * the training folder. A file that is already gone counts as success.
+ */
+export async function POST(request: Request) {
+  try {
+    const body = await request.json();
+    const { imgPath } = body;
+    if (typeof imgPath !== 'string') {
+      return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
+    }
+
+    // resolve first so `..` segments cannot slip past the prefix comparison
+    const resolvedImgPath = path.resolve(imgPath);
+    const allowedRoots = [await getDatasetsRoot(), await getTrainingFolder()].map(root => path.resolve(root));
+    if (!allowedRoots.some(root => resolvedImgPath.startsWith(root + path.sep))) {
+      return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
+    }
+
+    // make sure it is an image
+    if (!DELETABLE_MEDIA.test(resolvedImgPath)) {
+      return NextResponse.json({ error: 'Not an image' }, { status: 400 });
+    }
+
+    // if img doesnt exist, ignore
+    if (!fs.existsSync(resolvedImgPath)) {
+      return NextResponse.json({ success: true });
+    }
+
+    // delete it and return success
+    fs.unlinkSync(resolvedImgPath);
+
+    // check for caption
+    const captionPath = resolvedImgPath.replace(/\.[^/.]+$/, '') + '.txt';
+    if (fs.existsSync(captionPath)) {
+      // delete caption file
+      fs.unlinkSync(captionPath);
+    }
+
+    return NextResponse.json({ success: true });
+  } catch (error) {
+    console.error('Error deleting image:', error);
+    return NextResponse.json({ error: 'Failed to delete image' }, { status: 500 });
+  }
+}
diff --git a/ai-toolkit/ui/src/app/api/img/upload/route.ts b/ai-toolkit/ui/src/app/api/img/upload/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..56615bd06c4bfee9e7aef4b81a620d4c8c7cbcb7
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/img/upload/route.ts
@@ -0,0 +1,66 @@
+// src/app/api/img/upload/route.ts
+import { NextRequest, NextResponse } from 'next/server';
+import { writeFile, mkdir } from 'fs/promises';
+import { join } from 'path';
+import { getDataRoot } from '@/server/settings';
+import { v4 as uuidv4 } from 'uuid';
+
+/**
+ * POST /api/img/upload
+ *
+ * Multipart form field `files`: each upload is stored under `<data root>/images`
+ * as `<uuid>.<extension>`. Responds with the paths written.
+ */
+export async function POST(request: NextRequest) {
+  try {
+    const dataRoot = await getDataRoot();
+    if (!dataRoot) {
+      return NextResponse.json({ error: 'Data root path not found' }, { status: 500 });
+    }
+    const imgRoot = join(dataRoot, 'images');
+
+    const formData = await request.formData();
+    const files = formData.getAll('files');
+
+    if (!files || files.length === 0) {
+      return NextResponse.json({ error: 'No files provided' }, { status: 400 });
+    }
+
+    // make it recursive if it doesn't exist
+    await mkdir(imgRoot, { recursive: true });
+    const savedFiles = await Promise.all(
+      files.map(async (file: any) => {
+        const bytes = await file.arrayBuffer();
+        const buffer = Buffer.from(bytes);
+
+        // The extension comes from the client-supplied name: keep letters and
+        // digits only so it cannot carry path separators; fall back to jpg.
+        const extension = file.name.split('.').pop().replace(/[^a-zA-Z0-9]/g, '') || 'jpg';
+
+        // Use UUID for unique file names
+        const filePath = join(imgRoot, `${uuidv4()}.${extension}`);
+
+        await writeFile(filePath, buffer);
+        return filePath;
+      }),
+    );
+
+    return NextResponse.json({
+      message: 'Files uploaded successfully',
+      files: savedFiles,
+    });
+  } catch (error) {
+    console.error('Upload error:', error);
+    return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
+  }
+}
+
+// Increase payload size limit (default is 4mb)
+// NOTE(review): `config.api` is the Pages Router option format — confirm it has
+// any effect on this App Router route handler.
+export const config = {
+  api: {
+    bodyParser: false,
+    responseLimit: '50mb',
+  },
+};
diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/delete/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/delete/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..626c0053bcce9654f21451169083406397760157
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/delete/route.ts
@@ -0,0 +1,40 @@
+import { NextRequest, NextResponse } from 'next/server';
+import { PrismaClient } from '@prisma/client';
+import { getTrainingFolder } from '@/server/settings';
+import path from 'path';
+import fs from 'fs';
+
+const prisma = new PrismaClient();
+
+/**
+ * GET /api/jobs/[jobID]/delete
+ *
+ * Removes the job's folder (`<training folder>/<job.name>`) from disk, deletes
+ * the job row and responds with the deleted job. 404 when the id is unknown.
+ */
+export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
+  const { jobID } = await params;
+
+  const job = await prisma.job.findUnique({
+    where: { id: jobID },
+  });
+
+  if (!job) {
+    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
+  }
+
+  const trainingRoot = path.resolve(await getTrainingFolder());
+  const trainingFolder = path.resolve(trainingRoot, job.name);
+
+  // Only remove a folder strictly inside the training root: a job name such as
+  // '' or '..' would otherwise point the recursive delete at the root or above.
+  if (trainingFolder.startsWith(trainingRoot + path.sep) && fs.existsSync(trainingFolder)) {
+    fs.rmSync(trainingFolder, { recursive: true, force: true });
+  }
+
+  await prisma.job.delete({
+    where: { id: jobID },
+  });
+
+  return NextResponse.json(job);
+}
diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/files/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/files/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..bc4cc69caf7e3876fdc4fcdb8a366fd15825bc7a
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/files/route.ts
@@ -0,0 +1,52 @@
+import { NextRequest, NextResponse } from 'next/server';
+import { PrismaClient } from '@prisma/client';
+import path from 'path';
+import fs from 'fs';
+import { getTrainingFolder } from '@/server/settings';
+
+const prisma = new PrismaClient();
+
+/**
+ * GET /api/jobs/[jobID]/files
+ *
+ * Responds with `{ files: [{ path, size }] }` for every `.safetensors` file in
+ * the job folder (sorted by path), followed by `optimizer.pt` when it exists.
+ */
+export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
+  const { jobID } = await params;
+
+  const job = await prisma.job.findUnique({
+    where: { id: jobID },
+  });
+
+  if (!job) {
+    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
+  }
+
+  const trainingFolder = await getTrainingFolder();
+  const jobFolder = path.join(trainingFolder, job.name);
+
+  if (!fs.existsSync(jobFolder)) {
+    return NextResponse.json({ files: [] });
+  }
+
+  // checkpoints first (sorted), then the optimizer state if it exists
+  const candidates = fs
+    .readdirSync(jobFolder)
+    .filter(file => file.endsWith('.safetensors'))
+    .map(file => path.join(jobFolder, file))
+    .sort();
+
+  const optimizerPath = path.join(jobFolder, 'optimizer.pt');
+  if (fs.existsSync(optimizerPath)) {
+    candidates.push(optimizerPath);
+  }
+
+  // get the file size for each file
+  const fileObjects = candidates.map(file => ({
+    path: file,
+    size: fs.statSync(file).size,
+  }));
+
+  return NextResponse.json({ files: fileObjects });
+}
diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/log/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/log/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..10ccbdaac76b76ec20cead8e7f634af0d723ad8f
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/log/route.ts
@@ -0,0 +1,39 @@
+import { NextRequest, NextResponse } from 'next/server';
+import { PrismaClient } from '@prisma/client';
+import path from 'path';
+import fs from 'fs';
+import { getTrainingFolder } from '@/server/settings';
+
+const prisma = new PrismaClient();
+
+/**
+ * GET /api/jobs/[jobID]/log
+ *
+ * Responds with `{ log }`, the UTF-8 content of `log.txt` in the job folder
+ * (empty string when the file does not exist).
+ */
+export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
+  const { jobID } = await params;
+
+  const job = await prisma.job.findUnique({
+    where: { id: jobID },
+  });
+
+  if (!job) {
+    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
+  }
+
+  const trainingFolder = await getTrainingFolder();
+  const logPath = path.join(trainingFolder, job.name, 'log.txt');
+
+  if (!fs.existsSync(logPath)) {
+    return NextResponse.json({ log: '' });
+  }
+
+  try {
+    return NextResponse.json({ log: fs.readFileSync(logPath, 'utf-8') });
+  } catch (error) {
+    console.error('Error reading log file:', error);
+    return NextResponse.json({ log: 'Error reading log file' });
+  }
+}
diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/loss/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/loss/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..3aaeb50e775e78df0a9ad1f68d32d78596d9ecd9
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/loss/route.ts
@@ -0,0 +1,120 @@
+import { NextRequest, NextResponse } from 'next/server';
+import { PrismaClient } from '@prisma/client';
+import path from 'path';
+import fs from 'fs';
+import { getTrainingFolder } from '@/server/settings';
+
+import sqlite3 from 'sqlite3';
+
+export const runtime = 'nodejs';
+
+const prisma = new PrismaClient();
+
+/** Opens the SQLite file and sets a busy timeout of 30_000 ms. */
+function openDb(filename: string) {
+  const db = new sqlite3.Database(filename);
+  db.configure('busyTimeout', 30_000);
+  return db;
+}
+
+/** Promise wrapper around `db.all`: resolves with every row of the query. */
+function all<T = unknown>(db: sqlite3.Database, sql: string, params: any[] = []) {
+  return new Promise<T[]>((resolve, reject) => {
+    db.all(sql, params, (err, rows) => {
+      if (err) reject(err);
+      else resolve(rows as T[]);
+    });
+  });
+}
+
+/** Promise wrapper around `db.close`. */
+function closeDb(db: sqlite3.Database) {
+  return new Promise<void>((resolve, reject) => {
+    db.close((err) => (err ? reject(err) : resolve()));
+  });
+}
+
+/**
+ * Reads an integer query parameter. Missing or non-numeric values yield
+ * `fallback`; anything else is truncated and clamped into [min, max].
+ */
+function intParam(raw: string | null, fallback: number, min: number, max: number) {
+  const parsed = raw == null || raw.trim() === '' ? NaN : Number(raw);
+  if (!Number.isFinite(parsed)) return fallback;
+  return Math.min(max, Math.max(min, Math.trunc(parsed)));
+}
+
+/**
+ * GET /api/jobs/[jobID]/loss?key=&limit=&since_step=&stride=
+ *
+ * Reads the job's `loss_log.db` and responds with `{ key, keys, points }`:
+ * `keys` lists every metric key, `points` holds up to `limit` (default 2000,
+ * max 20000) samples of `key` (default `loss`) with step > `since_step` and
+ * step divisible by `stride`, in ascending step order.
+ */
+export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
+  // this must be awaited to avoid TS error
+  const { jobID } = await params;
+
+  const job = await prisma.job.findUnique({ where: { id: jobID } });
+  if (!job) return NextResponse.json({ error: 'Job not found' }, { status: 404 });
+
+  const trainingFolder = await getTrainingFolder();
+  const jobFolder = path.join(trainingFolder, job.name);
+  const logPath = path.join(jobFolder, 'loss_log.db');
+
+  if (!fs.existsSync(logPath)) {
+    return NextResponse.json({ keys: [], key: 'loss', points: [] });
+  }
+
+  const url = new URL(request.url);
+  const key = url.searchParams.get('key') ?? 'loss';
+  // Sanitise the numeric parameters: NaN would otherwise reach SQLite as a bind value
+  const limit = intParam(url.searchParams.get('limit'), 2000, 0, 20000);
+  const sinceStepParam = url.searchParams.get('since_step');
+  const sinceStep = sinceStepParam != null && Number.isFinite(Number(sinceStepParam)) ? Number(sinceStepParam) : null;
+  const stride = intParam(url.searchParams.get('stride'), 1, 1, Number.MAX_SAFE_INTEGER);
+
+  const db = openDb(logPath);
+
+  try {
+    const keysRows = await all<{ key: string }>(db, `SELECT key FROM metric_keys ORDER BY key ASC`);
+    const keys = keysRows.map((r) => r.key);
+
+    const points = await all<{
+      step: number;
+      wall_time: number;
+      value: number | null;
+      value_text: string | null;
+    }>(
+      db,
+      `
+      SELECT
+        m.step AS step,
+        s.wall_time AS wall_time,
+        m.value_real AS value,
+        m.value_text AS value_text
+      FROM metrics m
+      JOIN steps s ON s.step = m.step
+      WHERE m.key = ?
+        AND (? IS NULL OR m.step > ?)
+        AND (m.step % ?) = 0
+      ORDER BY m.step ASC
+      LIMIT ?
+    `,
+      [key, sinceStep, sinceStep, stride, limit]
+    );
+
+    return NextResponse.json({
+      key,
+      keys,
+      points: points.map((p) => ({
+        step: p.step,
+        wall_time: p.wall_time,
+        value: p.value ?? (p.value_text ? Number(p.value_text) : null),
+      })),
+    });
+  } finally {
+    await closeDb(db);
+  }
+}
diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/mark_stopped/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/mark_stopped/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..39cd353435e08ff66b9d5ac70522c3755284c880
--- /dev/null
+++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/mark_stopped/route.ts
@@ -0,0 +1,38 @@
+import { NextRequest, NextResponse } from 'next/server';
+import { PrismaClient } from '@prisma/client';
+
+const prisma = new PrismaClient();
+
+/**
+ * GET /api/jobs/[jobID]/mark_stopped
+ *
+ * Flags the job as stopped in the database (`stop: true`, `status: 'stopped'`,
+ * `pid: null`) and responds with the updated job. 404 when the id is unknown.
+ */
+export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
+  const { jobID } = await params;
+
+  const job = await prisma.job.findUnique({
+    where: { id: jobID },
+  });
+
+  // without this guard prisma.job.update() below rejects for an unknown id
+  if (!job) {
+    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
+  }
+
+  // update job status to 'stopped'
+  const updatedJob = await prisma.job.update({
+    where: { id: jobID },
+    data: {
+      stop: true,
+      status: 'stopped',
+      info: 'Job stopped',
+      pid: null,
+    },
+  });
+
+  console.log(`Job ${jobID} marked as stopped`);
+
+  return NextResponse.json(updatedJob);
+}
diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/plugin/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/plugin/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..83f09e5d0323591533a0fbe8a144d5e203800b84 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/plugin/route.ts @@ -0,0 +1,47 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import path from 'path'; +import fs from 'fs'; +import { getTrainingFolder } from '@/server/settings'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + const trainingFolder = await getTrainingFolder(); + const jobFolder = 
path.join(trainingFolder, job.name); + const pluginPath = path.join(jobFolder, 'plugin.html'); + + if (!fs.existsSync(pluginPath)) { + return NextResponse.json({ exists: false, html: null }); + } + + // lightweight existence check used to decide if the Plugin tab should show + if (request.nextUrl.searchParams.get('check') === '1') { + return NextResponse.json({ exists: true, html: null }); + } + + // serve the raw html so it can be loaded directly as an iframe src + let html = ''; + try { + html = fs.readFileSync(pluginPath, 'utf-8'); + } catch (error) { + console.error('Error reading plugin file:', error); + return NextResponse.json({ error: 'Error reading plugin file' }, { status: 500 }); + } + return new NextResponse(html, { + headers: { + 'Content-Type': 'text/html; charset=utf-8', + 'Cache-Control': 'no-store', + }, + }); +} diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/samples/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/samples/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..05918bd640de7e426d8d4628ca72148ee85ccae0 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/samples/route.ts @@ -0,0 +1,40 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import path from 'path'; +import fs from 'fs'; +import { getTrainingFolder } from '@/server/settings'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + // setup the training + const trainingFolder = await getTrainingFolder(); + + const samplesFolder = path.join(trainingFolder, job.name, 'samples'); + if (!fs.existsSync(samplesFolder)) { + return NextResponse.json({ samples: [] }); + } + + // find all img (png, jpg, jpeg) 
files in the samples folder + const samples = fs + .readdirSync(samplesFolder) + .filter(file => { + return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp') || file.endsWith('.mp4') || file.endsWith('mp3') || file.endsWith('wav') || file.endsWith('flac') || file.endsWith('ogg'); + }) + .map(file => { + return path.join(samplesFolder, file); + }) + .sort(); + + return NextResponse.json({ samples }); +} diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/save_now/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/save_now/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..0a6e54bd126a9117149705418f7d5dc082eadb36 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/save_now/route.ts @@ -0,0 +1,19 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.update({ + where: { id: jobID }, + data: { + save_now: true, + }, + }); + + console.log(`Job ${jobID} marked to save on next step`); + + return NextResponse.json(job); +} diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/start/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/start/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..0417dcff74ef9539b44bd37817b8e8a9dedeb1c6 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/start/route.ts @@ -0,0 +1,59 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job 
not found' }, { status: 404 }); + } + + // get highest queue position + const highestQueuePosition = await prisma.job.aggregate({ + _max: { + queue_position: true, + }, + }); + const newQueuePosition = (highestQueuePosition._max.queue_position || 0) + 1000; + + await prisma.job.update({ + where: { id: jobID }, + data: { queue_position: newQueuePosition }, + }); + + // make sure the queue is running + const queue = await prisma.queue.findFirst({ + where: { + gpu_ids: job.gpu_ids, + }, + }); + + // if queue doesn't exist, create it + if (!queue) { + await prisma.queue.create({ + data: { + gpu_ids: job.gpu_ids, + is_running: false, + }, + }); + } + + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'queued', + stop: false, + return_to_queue: false, + info: 'Job queued', + }, + }); + + // Return the response immediately + return NextResponse.json(job); +} diff --git a/ai-toolkit/ui/src/app/api/jobs/[jobID]/stop/route.ts b/ai-toolkit/ui/src/app/api/jobs/[jobID]/stop/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..417a559749f2848fe3fdd450cffd358859f15a4c --- /dev/null +++ b/ai-toolkit/ui/src/app/api/jobs/[jobID]/stop/route.ts @@ -0,0 +1,55 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); +const isWindows = process.platform === 'win32'; + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + await prisma.job.update({ + where: { id: jobID }, + data: { + stop: true, + info: 'Stopping job...', + }, + }); + + // Send SIGINT to the process if we have a PID + if (job.pid != null) { + console.log(`Attempting to stop job ${jobID} with PID ${job.pid}`); + try { + if (isWindows) { + // Windows 
doesn't support SIGINT for arbitrary processes. + // Use taskkill with /T (tree) to send a CTRL+C-like termination. + const { execSync } = require('child_process'); + execSync(`taskkill /PID ${job.pid} /T /F`, { stdio: 'ignore' }); + } else { + process.kill(job.pid, 'SIGINT'); + } + // if it killed it, mark it stopped in the database + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'stopped', + info: 'Job stopped', + }, + }); + } catch (e) { + // Process may have already exited — that's fine + console.error('Error sending signal to process:', e); + } + } else { + console.warn(`No PID found for job ${jobID}, cannot send stop signal`); + } + + return NextResponse.json(job); +} diff --git a/ai-toolkit/ui/src/app/api/jobs/route.ts b/ai-toolkit/ui/src/app/api/jobs/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..11dd95271418034f70577be6a0e97d0f8aa56e8c --- /dev/null +++ b/ai-toolkit/ui/src/app/api/jobs/route.ts @@ -0,0 +1,100 @@ +import { NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import { isMac } from '@/helpers/basic'; + +const prisma = new PrismaClient(); + +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + const id = searchParams.get('id'); + const job_ref = searchParams.get('job_ref'); + const job_type = searchParams.get('job_type'); + + try { + if (id) { + const job = await prisma.job.findUnique({ + where: { id }, + }); + return NextResponse.json(job); + } + if (job_ref) { + const job = await prisma.job.findFirst({ + where: { job_ref }, + orderBy: { updated_at: 'desc' }, + }); + return NextResponse.json(job); + } + + const jobs = await prisma.job.findMany({ + where: job_type ? 
{ job_type } : undefined, + orderBy: { created_at: 'desc' }, + }); + return NextResponse.json({ jobs: jobs }); + } catch (error) { + console.error(error); + return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 }); + } +} + +export async function POST(request: Request) { + try { + const body = await request.json(); + const { id, name, job_config } = body; + let gpu_ids: string = body.gpu_ids; + + if (isMac()) { + gpu_ids = "mps"; + } + + const extra: any = {}; + if ("job_ref" in body) { + extra["job_ref"] = body.job_ref; + } + + if ("job_type" in body) { + extra["job_type"] = body.job_type; + } + + if (id) { + // Update existing training + const training = await prisma.job.update({ + where: { id }, + data: { + name, + gpu_ids, + job_config: JSON.stringify(job_config), + ...extra, + }, + }); + return NextResponse.json(training); + } else { + // find the highest queue position and add 1000 + const highestQueuePosition = await prisma.job.aggregate({ + _max: { + queue_position: true, + }, + }); + const newQueuePosition = (highestQueuePosition._max.queue_position || 0) + 1000; + + // Create new training + const training = await prisma.job.create({ + data: { + name, + gpu_ids, + job_config: JSON.stringify(job_config), + queue_position: newQueuePosition, + ...extra, + }, + }); + return NextResponse.json(training); + } + } catch (error: any) { + if (error.code === 'P2002') { + // Handle unique constraint violation, 409=Conflict + return NextResponse.json({ error: 'Job name already exists' }, { status: 409 }); + } + console.error(error); + // Handle other errors + return NextResponse.json({ error: 'Failed to save training data' }, { status: 500 }); + } +} diff --git a/ai-toolkit/ui/src/app/api/ostris_cloud/route.ts b/ai-toolkit/ui/src/app/api/ostris_cloud/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..b883928547f3b4feaedf39d4e892c31156c5e970 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/ostris_cloud/route.ts @@ 
-0,0 +1,38 @@ +import { NextResponse } from 'next/server'; + +export async function GET() { + const appUrl = process.env.OSTRIS_CLOUD_APP_URL; + const apiKey = process.env.OSTRIS_CLOUD_API_KEY; + + if (!appUrl || !apiKey) { + return NextResponse.json({ enabled: false }); + } + + try { + const res = await fetch(`${appUrl}/api/machine/me`, { + headers: { Authorization: `Bearer ${apiKey}` }, + cache: 'no-store', + }); + + if (!res.ok) { + return NextResponse.json({ + enabled: true, + appUrl, + error: `Ostris Cloud responded with ${res.status}`, + }); + } + + const data = await res.json(); + return NextResponse.json({ + enabled: true, + appUrl, + balance: data.balance ?? null, + }); + } catch (error) { + return NextResponse.json({ + enabled: true, + appUrl, + error: `Failed to fetch Ostris Cloud balance: ${error instanceof Error ? error.message : String(error)}`, + }); + } +} diff --git a/ai-toolkit/ui/src/app/api/queue/[queueID]/start/route.ts b/ai-toolkit/ui/src/app/api/queue/[queueID]/start/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..d67ff4ecbb7cd60cefdc6045f4b19cc590973d82 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/queue/[queueID]/start/route.ts @@ -0,0 +1,27 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { queueID: string } }) { + const { queueID } = await params; + + const queue = await prisma.queue.findUnique({ + where: { gpu_ids: queueID }, + }); + + if (!queue) { + // create it if it doesn't exist + const newQueue = await prisma.queue.create({ + data: { gpu_ids: queueID, is_running: true }, + }); + return NextResponse.json(newQueue); + } + + await prisma.queue.update({ + where: { id: queue.id }, + data: { is_running: true }, + }); + + return NextResponse.json(queue); +} diff --git a/ai-toolkit/ui/src/app/api/queue/[queueID]/stop/route.ts 
b/ai-toolkit/ui/src/app/api/queue/[queueID]/stop/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..87e608b108f3bb71e5b7e04b02ff7c17581ec4a4 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/queue/[queueID]/stop/route.ts @@ -0,0 +1,23 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { queueID: string } }) { + const { queueID } = await params; + + const queue = await prisma.queue.findUnique({ + where: { gpu_ids: queueID }, + }); + + if (!queue) { + return NextResponse.json({ error: 'Queue not found' }, { status: 404 }); + } + + await prisma.queue.update({ + where: { id: queue.id }, + data: { is_running: false }, + }); + + return NextResponse.json(queue); +} diff --git a/ai-toolkit/ui/src/app/api/queue/route.ts b/ai-toolkit/ui/src/app/api/queue/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..08c6d4bcd8d93851d1b7954db27699ec0c11293b --- /dev/null +++ b/ai-toolkit/ui/src/app/api/queue/route.ts @@ -0,0 +1,18 @@ +import { NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + + try { + const queues = await prisma.queue.findMany({ + orderBy: { gpu_ids: 'asc' }, + }); + return NextResponse.json({ queues: queues }); + } catch (error) { + console.error(error); + return NextResponse.json({ error: 'Failed to fetch queue' }, { status: 500 }); + } +} diff --git a/ai-toolkit/ui/src/app/api/scripts/route.ts b/ai-toolkit/ui/src/app/api/scripts/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..bab8af491ee4111c4aeb5938acaff1c491356afb --- /dev/null +++ b/ai-toolkit/ui/src/app/api/scripts/route.ts @@ -0,0 +1,248 @@ +import { NextResponse } from 'next/server'; +import { spawn 
} from 'child_process'; +import path from 'path'; +import fs from 'fs'; +import { TOOLKIT_ROOT } from '@/paths'; +import { resolvePythonPath } from '../../../../cron/pythonPath'; + +// Long-running scripts: allow up to 20 minutes. +export const runtime = 'nodejs'; +export const dynamic = 'force-dynamic'; +export const maxDuration = 1200; + +const TIMEOUT_MS = 20 * 60 * 1000; +const UI_SCRIPTS_ROOT = path.join(TOOLKIT_ROOT, 'ui_scripts'); +// Only allow flat script names (no path separators, no traversal). +const SCRIPT_NAME_RE = /^[A-Za-z0-9_][A-Za-z0-9_.-]*\.py$/; + +const resolveScriptPath = (rawName: unknown): string | null => { + if (typeof rawName !== 'string') return null; + const name = rawName.trim(); + if (!SCRIPT_NAME_RE.test(name)) return null; + + const target = path.resolve(UI_SCRIPTS_ROOT, name); + const rootWithSep = UI_SCRIPTS_ROOT.endsWith(path.sep) ? UI_SCRIPTS_ROOT : UI_SCRIPTS_ROOT + path.sep; + if (!target.startsWith(rootWithSep)) return null; + if (!fs.existsSync(target) || !fs.statSync(target).isFile()) return null; + return target; +}; + +// Args may be a positional list or an object that becomes --key value pairs. +// Every value is stringified before being passed to spawn (no shell). 
/**
 * Converts the request's `args` payload into an argv list for the script.
 *
 * - `null` / `undefined` yields an empty list.
 * - An array is treated as positional arguments: string, number and boolean
 *   entries are stringified; `null` / `undefined` entries are dropped.
 * - Any other object is turned into flags: `true` emits a bare `--key`,
 *   `false` / `null` / `undefined` omit the flag, and string or number values
 *   emit `--key value`.
 *
 * @param raw - Unvalidated `args` value taken from the JSON request body.
 * @returns The argv list, or `{ error }` describing the first invalid entry.
 */
const normalizeArgs = (raw: unknown): string[] | { error: string } => {
  if (raw == null) return [];

  if (Array.isArray(raw)) {
    const out: string[] = [];
    for (const v of raw) {
      if (v == null) continue;
      if (typeof v === 'string' || typeof v === 'number' || typeof v === 'boolean') {
        out.push(String(v));
      } else {
        return { error: 'args entries must be string|number|boolean' };
      }
    }
    return out;
  }

  if (typeof raw === 'object') {
    const out: string[] = [];
    for (const [key, value] of Object.entries(raw as Record<string, unknown>)) {
      // Keys become `--flag` names, so only a conservative character set is accepted.
      if (!/^[A-Za-z0-9_-]+$/.test(key)) return { error: `invalid arg key: ${key}` };
      const flag = `--${key}`;
      if (value === true) {
        // Boolean switch: present without a value.
        out.push(flag);
      } else if (value === false || value == null) {
        // Disabled or unset option: leave the flag out entirely.
        continue;
      } else if (typeof value === 'string' || typeof value === 'number') {
        out.push(flag, String(value));
      } else {
        return { error: `args.${key} must be string|number|boolean` };
      }
    }
    return out;
  }

  return { error: 'args must be an array or object' };
};

/** Outcome of a buffered script run; serialized as the JSON response body. */
interface RunResult {
  /** True only when the process exited with code 0 and did not time out. */
  ok: boolean;
  /** Process exit code, or null when none was reported. */
  exitCode: number | null;
  /** Signal reported when the process closed, if any. */
  signal: NodeJS.Signals | null;
  /** Everything the script wrote to stdout. */
  stdout: string;
  /** Everything the script wrote to stderr. */
  stderr: string;
  /** Structured value parsed from the last stdout line, or null. */
  result: unknown;
  /** True when the run was killed for exceeding the timeout. */
  timedOut: boolean;
  /** Spawn failure or timeout message, when applicable. */
  error?: string;
}

// Parses the last line of stdout as JSON if possible — scripts can use this
// to return structured data alongside their human-readable logs.
const parseResult = (stdout: string): unknown => {
  // After trimEnd() the final line is the last non-blank line (or '' when
  // stdout was empty or whitespace-only), so only that line needs inspecting.
  const lines = stdout.trimEnd().split(/\r?\n/);
  const last = (lines[lines.length - 1] ?? '').trim();
  if (!last.startsWith('{') && !last.startsWith('[')) return null;
  try {
    return JSON.parse(last);
  } catch {
    // Looked like JSON but did not parse: report "no structured result".
    return null;
  }
};

/**
 * Runs a script to completion and collects its output in memory.
 *
 * The promise never rejects: spawn failures, non-zero exits and timeouts are
 * all reported through the resolved `RunResult`.
 *
 * @param scriptPath - Path of the Python script to execute.
 * @param args - Arguments appended after the script path (no shell involved).
 * @returns The captured stdout/stderr, exit status and parsed result.
 */
const runBuffered = (scriptPath: string, args: string[]): Promise<RunResult> => {
  return new Promise<RunResult>(resolve => {
    const child = spawn(resolvePythonPath(), ['-u', scriptPath, ...args], {
      cwd: TOOLKIT_ROOT,
      env: { ...process.env, PYTHONUNBUFFERED: '1', PYTHONIOENCODING: 'utf-8' },
      windowsHide: true,
    });

    let stdout = '';
    let stderr = '';
    let timedOut = false;

    // Hard stop for runaway scripts; the 'close' handler reports the timeout.
    const timer = setTimeout(() => {
      timedOut = true;
      child.kill('SIGKILL');
    }, TIMEOUT_MS);

    // Decode at the stream level so a multi-byte UTF-8 character that is split
    // across two chunks is reassembled instead of turning into U+FFFD.
    child.stdout.setEncoding('utf8');
    child.stderr.setEncoding('utf8');
    child.stdout.on('data', (chunk: string) => {
      stdout += chunk;
    });
    child.stderr.on('data', (chunk: string) => {
      stderr += chunk;
    });

    child.on('error', err => {
      clearTimeout(timer);
      // If 'close' also fires afterwards, its resolve() is a no-op because the
      // promise has already settled here.
      resolve({
        ok: false,
        exitCode: null,
        signal: null,
        stdout,
        stderr,
        result: null,
        timedOut,
        error: err.message,
      });
    });

    child.on('close', (code, signal) => {
      clearTimeout(timer);
      resolve({
        ok: !timedOut && code === 0,
        exitCode: code,
        signal,
        stdout,
        stderr,
        result: parseResult(stdout),
        timedOut,
        error: timedOut ? 'Script timed out after 20 minutes' : undefined,
      });
    });
  });
};

// NDJSON stream: one JSON object per line so clients can parse incrementally.
+const runStreaming = (scriptPath: string, args: string[]): Response => { + const child = spawn(resolvePythonPath(), ['-u', scriptPath, ...args], { + cwd: TOOLKIT_ROOT, + env: { ...process.env, PYTHONUNBUFFERED: '1', PYTHONIOENCODING: 'utf-8' }, + windowsHide: true, + }); + + const encoder = new TextEncoder(); + let stdoutBuf = ''; + let stderrBuf = ''; + let timedOut = false; + + const stream = new ReadableStream({ + start(controller) { + const send = (obj: unknown) => { + controller.enqueue(encoder.encode(JSON.stringify(obj) + '\n')); + }; + + const timer = setTimeout(() => { + timedOut = true; + send({ type: 'error', message: 'Script timed out after 20 minutes' }); + child.kill('SIGKILL'); + }, TIMEOUT_MS); + + child.stdout.on('data', (chunk: Buffer) => { + const text = chunk.toString('utf-8'); + stdoutBuf += text; + send({ type: 'stdout', data: text }); + }); + child.stderr.on('data', (chunk: Buffer) => { + const text = chunk.toString('utf-8'); + stderrBuf += text; + send({ type: 'stderr', data: text }); + }); + + child.on('error', err => { + clearTimeout(timer); + send({ type: 'error', message: err.message }); + controller.close(); + }); + + child.on('close', (code, signal) => { + clearTimeout(timer); + send({ + type: 'exit', + exitCode: code, + signal, + ok: !timedOut && code === 0, + timedOut, + result: parseResult(stdoutBuf), + stderr: stderrBuf, + }); + controller.close(); + }); + }, + cancel() { + if (!child.killed) child.kill('SIGKILL'); + }, + }); + + return new Response(stream, { + headers: { + 'Content-Type': 'application/x-ndjson; charset=utf-8', + 'Cache-Control': 'no-cache, no-transform', + 'X-Accel-Buffering': 'no', + }, + }); +}; + +export async function POST(request: Request) { + let body: any; + try { + body = await request.json(); + } catch { + return NextResponse.json({ error: 'Invalid JSON body' }, { status: 400 }); + } + + const scriptPath = resolveScriptPath(body?.script); + if (!scriptPath) { + return NextResponse.json( + { error: 
'Invalid or unknown script. Must be a *.py file inside ui_scripts/.' }, + { status: 400 }, + ); + } + + const normalized = normalizeArgs(body?.args); + if (!Array.isArray(normalized)) { + return NextResponse.json({ error: normalized.error }, { status: 400 }); + } + + if (body?.stream === true) { + return runStreaming(scriptPath, normalized); + } + + const result = await runBuffered(scriptPath, normalized); + const status = result.ok ? 200 : result.timedOut ? 504 : 500; + return NextResponse.json(result, { status }); +} diff --git a/ai-toolkit/ui/src/app/api/settings/route.ts b/ai-toolkit/ui/src/app/api/settings/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..62528cdd0b6a7de39c7ade3e96ea9f0b1ec2a226 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/settings/route.ts @@ -0,0 +1,59 @@ +import { NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths'; +import { flushCache } from '@/server/settings'; + +const prisma = new PrismaClient(); + +export async function GET() { + try { + const settings = await prisma.settings.findMany(); + const settingsObject = settings.reduce((acc: any, setting) => { + acc[setting.key] = setting.value; + return acc; + }, {}); + // if TRAINING_FOLDER is not set, use default + if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') { + settingsObject.TRAINING_FOLDER = defaultTrainFolder; + } + // if DATASETS_FOLDER is not set, use default + if (!settingsObject.DATASETS_FOLDER || settingsObject.DATASETS_FOLDER === '') { + settingsObject.DATASETS_FOLDER = defaultDatasetsFolder; + } + return NextResponse.json(settingsObject); + } catch (error) { + return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 }); + } +} + +export async function POST(request: Request) { + try { + const body = await request.json(); + const { HF_TOKEN, TRAINING_FOLDER, DATASETS_FOLDER } = body; + + // Upsert 
both settings + await Promise.all([ + prisma.settings.upsert({ + where: { key: 'HF_TOKEN' }, + update: { value: HF_TOKEN }, + create: { key: 'HF_TOKEN', value: HF_TOKEN }, + }), + prisma.settings.upsert({ + where: { key: 'TRAINING_FOLDER' }, + update: { value: TRAINING_FOLDER }, + create: { key: 'TRAINING_FOLDER', value: TRAINING_FOLDER }, + }), + prisma.settings.upsert({ + where: { key: 'DATASETS_FOLDER' }, + update: { value: DATASETS_FOLDER }, + create: { key: 'DATASETS_FOLDER', value: DATASETS_FOLDER }, + }), + ]); + + flushCache(); + + return NextResponse.json({ success: true }); + } catch (error) { + return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 }); + } +} diff --git a/ai-toolkit/ui/src/app/api/zip/route.ts b/ai-toolkit/ui/src/app/api/zip/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..fc4b946da5f6265d4d193849bf218fea41ea6e01 --- /dev/null +++ b/ai-toolkit/ui/src/app/api/zip/route.ts @@ -0,0 +1,78 @@ +/* eslint-disable */ +import { NextRequest, NextResponse } from 'next/server'; +import fs from 'fs'; +import fsp from 'fs/promises'; +import path from 'path'; +import archiver from 'archiver'; +import { getTrainingFolder } from '@/server/settings'; + +export const runtime = 'nodejs'; // ensure Node APIs are available +export const dynamic = 'force-dynamic'; // long-running, non-cached + +type PostBody = { + zipTarget: 'samples'; //only samples for now + jobName: string; +}; + +async function resolveSafe(p: string) { + // resolve symlinks + normalize + return await fsp.realpath(p); +} + +export async function POST(request: NextRequest) { + try { + const body = (await request.json()) as PostBody; + if (!body || !body.jobName) { + return NextResponse.json({ error: 'jobName is required' }, { status: 400 }); + } + + const trainingRoot = await resolveSafe(await getTrainingFolder()); + const folderPath = await resolveSafe(path.join(trainingRoot, body.jobName, 'samples')); + const outputPath = 
path.resolve(trainingRoot, body.jobName, 'samples.zip'); + + // Must be a directory + let stat: fs.Stats; + try { + stat = await fsp.stat(folderPath); + } catch { + return new NextResponse('Folder not found', { status: 404 }); + } + if (!stat.isDirectory()) { + return new NextResponse('Not a directory', { status: 400 }); + } + + // delete current one if it exists + if (fs.existsSync(outputPath)) { + await fsp.unlink(outputPath); + } + + // Create write stream & archive + await new Promise((resolve, reject) => { + const output = fs.createWriteStream(outputPath); + const archive = archiver('zip', { zlib: { level: 9 } }); + + output.on('close', () => resolve()); + output.on('error', reject); + archive.on('error', reject); + + archive.pipe(output); + + // Add the directory contents (place them under the folder's base name in the zip) + const rootName = path.basename(folderPath); + archive.directory(folderPath, rootName); + + archive.finalize().catch(reject); + }); + + // Return the absolute path so your existing /api/files/[...filePath] can serve it + // Example download URL (client-side): `/api/files/${encodeURIComponent(resolvedOutPath)}` + return NextResponse.json({ + ok: true, + zipPath: outputPath, + fileName: path.basename(outputPath), + }); + } catch (err) { + console.error('Zip error:', err); + return new NextResponse('Internal Server Error', { status: 500 }); + } +} diff --git a/ai-toolkit/ui/src/app/dashboard/page.tsx b/ai-toolkit/ui/src/app/dashboard/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..8685a10641c80d042f62b42f64a5bd728cae06a6 --- /dev/null +++ b/ai-toolkit/ui/src/app/dashboard/page.tsx @@ -0,0 +1,31 @@ +'use client'; + +import GpuMonitor from '@/components/GPUMonitor'; +import JobsTable from '@/components/JobsTable'; +import { TopBar, MainContent } from '@/components/layout'; +import Link from 'next/link'; + +export default function Dashboard() { + return ( + <> + +
+

Dashboard

+
+
+
+ + +
+
+

Queues

+
+ View All +
+
+ +
+
+ + ); +} diff --git a/ai-toolkit/ui/src/app/datasets/[datasetName]/page.tsx b/ai-toolkit/ui/src/app/datasets/[datasetName]/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..147f9c182962ebb4b7823175990acdf2f2789574 --- /dev/null +++ b/ai-toolkit/ui/src/app/datasets/[datasetName]/page.tsx @@ -0,0 +1,200 @@ +'use client'; + +import { useEffect, useState, use, useMemo, useCallback } from 'react'; +import { LuImageOff, LuLoader, LuBan } from 'react-icons/lu'; +import { FaChevronLeft } from 'react-icons/fa'; +import { VirtuosoGrid } from 'react-virtuoso'; +import DatasetImageCard from '@/components/DatasetImageCard'; +import DatasetImageViewer from '@/components/DatasetImageViewer'; +import { Button } from '@headlessui/react'; +import AddImagesModal, { openImagesModal, useOpenImagesModalOnDrag } from '@/components/AddImagesModal'; +import { TopBar, MainContent } from '@/components/layout'; +import { apiClient } from '@/utils/api'; +import useSettings from '@/hooks/useSettings'; +import { pathJoin } from '@/utils/basic'; +import AutoCaptionButton from '@/components/AutoCaptionButton'; +import CaptionMonitor from '@/components/CaptionMonitor'; +import { CreatableSelectInput } from '@/components/formInputs'; + +export default function DatasetPage({ params }: { params: { datasetName: string } }) { + const [imgList, setImgList] = useState<{ img_path: string }[]>([]); + const [isAutoCaptioning, setIsAutoCaptioning] = useState(false); + const usableParams = use(params as any) as { datasetName: string }; + const datasetName = usableParams.datasetName; + const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle'); + const { settings, isSettingsLoaded } = useSettings(); + const [selectedImgPath, setSelectedImgPath] = useState(null); + const [captionExt, setCaptionExt] = useState('txt'); + const [captionRefreshKeys, setCaptionRefreshKeys] = useState>({}); + const [scrollParent, setScrollParent] = useState(null); + const 
[captionBarHeight, setCaptionBarHeight] = useState(0); + const scrollParentCallback = useCallback((el: HTMLDivElement | null) => setScrollParent(el), []); + + const refreshImageList = (dbName: string) => { + setStatus('loading'); + apiClient + .post('/api/datasets/listImages', { datasetName: dbName }) + .then((res: any) => { + const data = res.data; + // Server already sorts; avoid the client-side sort that's expensive on large lists. + setImgList(data.images); + setStatus('success'); + }) + .catch(error => { + console.error('Error fetching images:', error); + setStatus('error'); + }); + }; + useOpenImagesModalOnDrag(datasetName, () => refreshImageList(datasetName)); + + const imgPaths = useMemo(() => imgList.map(img => img.img_path), [imgList]); + + useEffect(() => { + if (datasetName) { + refreshImageList(datasetName); + } + }, [datasetName]); + + const PageInfoContent = useMemo(() => { + let icon = null; + let text = ''; + let subtitle = ''; + let showIt = false; + let bgColor = ''; + let textColor = ''; + let iconColor = ''; + + if (status == 'loading') { + icon = ; + text = 'Loading Images'; + subtitle = 'Please wait while we fetch your dataset images...'; + showIt = true; + bgColor = 'bg-gray-800/50'; + textColor = 'text-gray-100'; + iconColor = 'text-gray-400'; + } + if (status == 'error') { + icon = ; + text = 'Error Loading Images'; + subtitle = 'There was a problem fetching the images. Please try refreshing the page.'; + showIt = true; + bgColor = 'bg-red-600/20'; + textColor = 'text-red-100'; + iconColor = 'text-red-400'; + } + if (status == 'success' && imgList.length === 0) { + icon = ; + text = 'No Images Found'; + subtitle = 'This dataset is empty. Click "Add Images" to get started.'; + showIt = true; + bgColor = 'bg-gray-800/50'; + textColor = 'text-gray-100'; + iconColor = 'text-gray-400'; + } + + if (!showIt) return null; + + return ( +
+
{icon}
+

{text}

+

{subtitle}

+
+ ); + }, [status, imgList.length]); + + return ( + <> + {/* Fixed top bar */} + +
+ +
+
+

+ Dataset: + {datasetName} +

+
+
+
+
+ + setCaptionExt(value)} + options={[ + { value: 'txt', label: 'txt' }, + { value: 'json', label: 'json' }, + { value: 'caption', label: 'caption' }, + ]} + /> +
+ + +
+
+ + {PageInfoContent} + {status === 'success' && imgList.length > 0 && scrollParent && ( + { + const img = imgList[index]; + if (!img) return null; + return ( + refreshImageList(datasetName)} + onImageClick={() => setSelectedImgPath(img.img_path)} + captionRefreshKey={captionRefreshKeys[img.img_path] || 0} + observerRoot={scrollParent} + captionExt={captionExt} + /> + ); + }} + computeItemKey={index => imgList[index]?.img_path ?? index} + /> + )} + {/* Spacer so the last cards stay accessible above the floating caption bar. + Always keeps a baseline gap, plus the bar height when it is showing. */} +
+ + + {isSettingsLoaded && ( + + )} + refreshImageList(datasetName)} + onCaptionSaved={path => setCaptionRefreshKeys(prev => ({ ...prev, [path]: (prev[path] || 0) + 1 }))} + captionExt={captionExt} + /> + + ); +} diff --git a/ai-toolkit/ui/src/app/datasets/page.tsx b/ai-toolkit/ui/src/app/datasets/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..9003dabe0e0768010a7cad511c231b224bc10a62 --- /dev/null +++ b/ai-toolkit/ui/src/app/datasets/page.tsx @@ -0,0 +1,175 @@ +'use client'; + +import { useState } from 'react'; +import { Modal } from '@/components/Modal'; +import Link from 'next/link'; +import { TextInput } from '@/components/formInputs'; +import useDatasetList from '@/hooks/useDatasetList'; +import { Button } from '@headlessui/react'; +import { FaRegTrashAlt } from 'react-icons/fa'; +import { openConfirm } from '@/components/ConfirmModal'; +import { TopBar, MainContent } from '@/components/layout'; +import UniversalTable, { TableColumn } from '@/components/UniversalTable'; +import { apiClient } from '@/utils/api'; +import { useRouter } from 'next/navigation'; + +export default function Datasets() { + const router = useRouter(); + const { datasets, status, refreshDatasets } = useDatasetList(); + const [newDatasetName, setNewDatasetName] = useState(''); + const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false); + + // Transform datasets array into rows with objects + const tableRows = datasets.map(dataset => ({ + name: dataset, + actions: dataset, // Pass full dataset name for actions + })); + + const columns: TableColumn[] = [ + { + title: 'Dataset Name', + key: 'name', + render: row => ( + + {row.name} + + ), + }, + { + title: 'Actions', + key: 'actions', + className: 'w-20 text-right', + render: row => ( + + ), + }, + ]; + + const handleDeleteDataset = (datasetName: string) => { + openConfirm({ + title: 'Delete Dataset', + message: `Are you sure you want to delete the dataset "${datasetName}"? 
This action cannot be undone.`, + type: 'warning', + confirmText: 'Delete', + onConfirm: () => { + apiClient + .post('/api/datasets/delete', { name: datasetName }) + .then(() => { + console.log('Dataset deleted:', datasetName); + refreshDatasets(); + }) + .catch(error => { + console.error('Error deleting dataset:', error); + }); + }, + }); + }; + + const handleCreateDataset = async (e: React.FormEvent) => { + e.preventDefault(); + try { + const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data); + console.log('New dataset created:', data); + refreshDatasets(); + setNewDatasetName(''); + setIsNewDatasetModalOpen(false); + } catch (error) { + console.error('Error creating new dataset:', error); + } + }; + + const openNewDatasetModal = () => { + openConfirm({ + title: 'New Dataset', + message: 'Enter the name of the new dataset:', + type: 'info', + confirmText: 'Create', + inputTitle: 'Dataset Name', + onConfirm: async (name?: string) => { + if (!name) { + console.error('Dataset name is required.'); + return; + } + try { + const data = await apiClient.post('/api/datasets/create', { name }).then(res => res.data); + console.log('New dataset created:', data); + if (data.name) { + router.push(`/datasets/${data.name}`); + } else { + refreshDatasets(); + } + } catch (error) { + console.error('Error creating new dataset:', error); + } + }, + }); + }; + + return ( + <> + +
+

Datasets

+
+
+
+ +
+
+ + + + + + setIsNewDatasetModalOpen(false)} + title="New Dataset" + size="md" + > +
+
+
+ This will create a new folder with the name below in your dataset folder. +
+
+ setNewDatasetName(value)} /> +
+ +
+ + +
+
+
+
+ + ); +} diff --git a/ai-toolkit/ui/src/app/favicon.ico b/ai-toolkit/ui/src/app/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..a20b629a5996a0b62c038bf356f1e28eab9bdb99 Binary files /dev/null and b/ai-toolkit/ui/src/app/favicon.ico differ diff --git a/ai-toolkit/ui/src/app/globals.css b/ai-toolkit/ui/src/app/globals.css new file mode 100644 index 0000000000000000000000000000000000000000..af2cca0f2b336658885010e706e8c5aecc2ab200 --- /dev/null +++ b/ai-toolkit/ui/src/app/globals.css @@ -0,0 +1,131 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +/* Light mode (default root values, overridden by .dark) */ +:root { + --background: #fafafa; + --foreground: #171717; + + /* Gray scale: inverted for light mode so bg-gray-950 = light bg, text-gray-100 = dark text */ + --gray-950: 235 235 235; + --gray-900: 215 215 215; + --gray-800: 195 195 195; + --gray-700: 175 175 175; + --gray-600: 163 163 163; + --gray-500: 115 115 115; + --gray-400: 82 82 82; + --gray-300: 64 64 64; + --gray-200: 38 38 38; + --gray-100: 23 23 23; +} + +/* Dark mode */ +.dark { + --background: #0a0a0a; + --foreground: #ededed; + + --gray-950: 10 10 10; + --gray-900: 23 23 23; + --gray-800: 38 38 38; + --gray-700: 64 64 64; + --gray-600: 82 82 82; + --gray-500: 115 115 115; + --gray-400: 163 163 163; + --gray-300: 212 212 212; + --gray-200: 229 229 229; + --gray-100: 245 245 245; +} + +body { + color: var(--foreground); + background: var(--background); + font-family: Arial, Helvetica, sans-serif; +} + +/* Prevent iOS Safari from auto-zooming when a form field is focused. + iOS only zooms when the field's font-size is < 16px. 
*/ +@media (max-width: 639px) { + input, + textarea, + select, + .aitk-react-select-container .aitk-react-select__control, + .aitk-react-select-container .aitk-react-select__input, + .aitk-react-select-container .aitk-react-select__single-value, + .aitk-react-select-container .aitk-react-select__placeholder, + .aitk-react-select-container .aitk-react-select__option { + font-size: 16px !important; + } +} + +@keyframes heartbeat { + 0%, + 40%, + 100% { + transform: scale(1); + } + 10% { + transform: scale(1.6); + } + 20% { + transform: scale(1); + } + 30% { + transform: scale(1.45); + } +} + +.animate-heartbeat { + animation: heartbeat 1.6s ease-in-out infinite; + transform-origin: center; + transform-box: fill-box; + will-change: transform; +} + +@layer components { + /* control */ + .aitk-react-select-container .aitk-react-select__control { + @apply flex w-full h-8 min-h-0 px-0 text-sm bg-gray-950 dark:bg-gray-800 border border-gray-700 rounded-sm hover:border-gray-600 items-center; + } + + /* selected label */ + .aitk-react-select-container .aitk-react-select__single-value { + @apply flex-1 min-w-0 truncate text-sm text-gray-200; + } + + /* invisible input (keeps focus & typing, never wraps) */ + .aitk-react-select-container .aitk-react-select__input-container { + @apply text-gray-200; + } + + /* focus */ + .aitk-react-select-container .aitk-react-select__control--is-focused { + @apply ring-2 ring-gray-600 border-transparent hover:border-transparent shadow-none; + } + + /* menu */ + .aitk-react-select-container .aitk-react-select__menu { + @apply bg-gray-950 dark:bg-gray-800 border border-gray-700; + } + + /* options */ + .aitk-react-select-container .aitk-react-select__option { + @apply text-sm text-gray-200 bg-gray-950 dark:bg-gray-800 hover:bg-gray-700 dark:hover:bg-gray-700 cursor-pointer; + } + + /* indicator separator */ + .aitk-react-select-container .aitk-react-select__indicator-separator { + @apply bg-gray-600; + } + + /* indicators */ + 
.aitk-react-select-container .aitk-react-select__indicators, + .aitk-react-select-container .aitk-react-select__indicator { + @apply py-0 flex items-center; + } + + /* placeholder */ + .aitk-react-select-container .aitk-react-select__placeholder { + @apply text-sm text-gray-200; + } +} diff --git a/ai-toolkit/ui/src/app/icon.svg b/ai-toolkit/ui/src/app/icon.svg new file mode 100644 index 0000000000000000000000000000000000000000..2689ae5393931a68144db7d92555343aeef0155c --- /dev/null +++ b/ai-toolkit/ui/src/app/icon.svg @@ -0,0 +1,3 @@ + \ No newline at end of file diff --git a/ai-toolkit/ui/src/app/jobs/[jobID]/page.tsx b/ai-toolkit/ui/src/app/jobs/[jobID]/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..7f5c2a8abee9c025fcd619fef6ab057745f25263 --- /dev/null +++ b/ai-toolkit/ui/src/app/jobs/[jobID]/page.tsx @@ -0,0 +1,167 @@ +'use client'; + +import { useState, useEffect, use } from 'react'; +import { FaChevronLeft } from 'react-icons/fa'; +import { MdDashboard, MdImage, MdShowChart, MdCode, MdExtension } from 'react-icons/md'; +import { Button } from '@headlessui/react'; +import { TopBar, MainContent } from '@/components/layout'; +import useJob from '@/hooks/useJob'; +import SampleImages, { SampleImagesMenu } from '@/components/SampleImages'; +import JobOverview from '@/components/JobOverview'; +import { redirect } from 'next/navigation'; +import JobActionBar from '@/components/JobActionBar'; +import JobConfigViewer from '@/components/JobConfigViewer'; +import JobLossGraph from '@/components/JobLossGraph'; +import JobPlugin from '@/components/JobPlugin'; +import { Job } from '@prisma/client'; +import { apiClient } from '@/utils/api'; + +type PageKey = 'overview' | 'samples' | 'config' | 'loss_log' | 'plugin'; + +interface Page { + name: string; + value: PageKey; + icon: React.ComponentType<{ className?: string }>; + component: React.ComponentType<{ job: Job }>; + menuItem?: React.ComponentType<{ job?: Job | null }> | null; + mainCss?: 
string; + jobTypes?: string[]; // if specified, only show this page for these job types +} + +/* Tab registry for the job page: each entry gives a display name, a PageKey value, an icon, the component rendered for that tab, and optionally a menu item, a mainCss class string, and a jobTypes restriction. */ const pages: Page[] = [ + { + name: 'Overview', + value: 'overview', + icon: MdDashboard, + component: JobOverview, + mainCss: 'pt-24', + }, + { + name: 'Samples', + value: 'samples', + icon: MdImage, + component: SampleImages, + menuItem: SampleImagesMenu, + mainCss: 'pt-24', + jobTypes: ['train'], + }, + { + name: 'Loss Graph', + value: 'loss_log', + icon: MdShowChart, + component: JobLossGraph, + mainCss: 'pt-24 pb-4', + jobTypes: ['train'], + }, + { + name: 'Config File', + value: 'config', + icon: MdCode, + component: JobConfigViewer, + mainCss: 'pt-[80px] px-0 pb-0', + }, + { + name: 'Plugin', + value: 'plugin', + icon: MdExtension, + component: JobPlugin, + mainCss: 'pt-[80px] px-0 pb-0', + }, +]; + +/* Job detail page: unwraps jobID from params, loads the job through useJob(jobID, 5000), keeps the selected tab in pageKey and the result of the plugin existence check in hasPlugin. */ export default function JobPage({ params }: { params: { jobID: string } }) { + const usableParams = use(params as any) as { jobID: string }; + const jobID = usableParams.jobID; + const { job, status, refreshJob } = useJob(jobID, 5000); + const [pageKey, setPageKey] = useState('overview'); + const [hasPlugin, setHasPlugin] = useState(false); + + // poll for plugin.html in the job folder; show the Plugin tab if it exists + useEffect(() => { + const checkPlugin = () => { + apiClient + .get(`/api/jobs/${jobID}/plugin?check=1`) + .then(res => res.data) + .then(data => setHasPlugin(!!data.exists)) + .catch(() => {}); /* request failures are ignored, so hasPlugin keeps its previous value */ + }; + /* run once immediately, then every 5000 ms until the effect is cleaned up or jobID changes */ checkPlugin(); + const interval = setInterval(checkPlugin, 5000); + return () => clearInterval(interval); + }, [jobID]); + + const page = pages.find(p => p.value === pageKey); + + const jobType = job?.job_type || 'unknown'; + + /* Caption jobs are titled by job_ref; every other job type is titled by name. */ let title = `Job: ${job?.name || 'Loading...'}`; + if (jobType === 'caption') { + title = `Captioning: ${job?.job_ref || 'Loading...'}`; + } + + return ( + <> + {/* Fixed top bar */} + +
+ +
+
+

{title}

+
+
+ {job && ( + { + redirect('/jobs'); + }} + autoStartQueue={true} + /> + )} +
+ page.value === pageKey)?.mainCss}> + {status === 'loading' && job == null &&

Loading...

} + {status === 'error' && job == null &&

Error fetching job

} + {job && ( + <> + {pages.map(page => { + const Component = page.component; + return page.value === pageKey ? : null; + })} + + )} +
+
+ {pages.map(page => { + if (page.jobTypes && !page.jobTypes.includes(jobType)) { + return null; + } + if (page.value === 'plugin' && !hasPlugin) { + return null; + } + return ( + + ); + })} + {page?.menuItem && ( + <> +
+ + + )} +
+ + ); +} diff --git a/ai-toolkit/ui/src/app/jobs/new/SimpleJob.tsx b/ai-toolkit/ui/src/app/jobs/new/SimpleJob.tsx new file mode 100644 index 0000000000000000000000000000000000000000..5f9d9bc15a0eb803540ce95ad319fc5e6b76c4fa --- /dev/null +++ b/ai-toolkit/ui/src/app/jobs/new/SimpleJob.tsx @@ -0,0 +1,1643 @@ +'use client'; +import { useMemo } from 'react'; +import { + modelArchs, + ModelArch, + groupedModelOptions, + quantizationOptions, + defaultQtype, + jobTypeOptions, + SampleTags, +} from './options'; +import { defaultCompileOptions, defaultDatasetConfig } from './jobConfig'; +import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; +import { objectCopy, tagsToObj, objToTags } from '@/utils/basic'; +import { + TextInput, + TextAreaInput, + SelectInput, + Checkbox, + FormGroup, + NumberInput, + SliderInput, + CreatableSelectInput, +} from '@/components/formInputs'; +import Card from '@/components/Card'; +import { X, Copy, Wand2, SquareDashed } from 'lucide-react'; +import { openUpsamplePromptsModal, toAspectRatio } from '@/components/UpsamplePromptsModal'; +import { openPromptBoxEditor } from '@/components/PromptBoxEditorModal'; +import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal'; +import SampleControlImage from '@/components/SampleControlImage'; +import { FlipHorizontal2, FlipVertical2 } from 'lucide-react'; +import { handleModelArchChange } from './utils'; +import { IoFlaskSharp } from 'react-icons/io5'; +import { isMac } from '@/helpers/basic'; + +type Props = { + jobConfig: JobConfig; + setJobConfig: (value: any, key: string) => void; + status: 'idle' | 'saving' | 'success' | 'error'; + handleSubmit: (event: React.FormEvent) => void; + runId: string | null; + gpuIDs: string | null; + setGpuIDs: (value: string | null) => void; + gpuList: any; + datasetOptions: any; + isLoading?: boolean; +}; + +const isDev = process.env.NODE_ENV === 'development'; + +export default function SimpleJob({ + jobConfig, + 
setJobConfig, + handleSubmit, + status, + runId, + gpuIDs, + setGpuIDs, + gpuList, + datasetOptions, + isLoading, +}: Props) { + /* Architecture entry whose name equals the configured model.arch. NOTE(review): find() yields undefined when nothing matches and the cast hides that; the code below guards with optional chaining. */ const modelArch = useMemo(() => { + return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; + }, [jobConfig.config.process[0].model.arch]); + + /* Job type option whose value equals the configured process type, or undefined when none matches. */ const jobType = useMemo(() => { + return jobTypeOptions.find(j => j.value === jobConfig.config.process[0].type); + }, [jobConfig.config.process[0].type]); + + /* Concatenation of the section names disabled by the model architecture and by the job type. */ const disableSections = useMemo(() => { + let sections: string[] = []; + if (modelArch?.disableSections) { + sections = sections.concat(modelArch.disableSections); + } + if (jobType?.disableSections) { + sections = sections.concat(jobType.disableSections); + } + return sections; + }, [modelArch, jobType]); + + const isVideoModel = !!(modelArch?.group === 'video'); + const isAudioModel = !!(modelArch?.group === 'audio'); + + /* One tagsToObj result per sample prompt; null when there is no architecture, the architecture has no sampleTags, or the config has no samples. */ const taggedSampleArr: Record[] | null = useMemo(() => { + if (!modelArch) return null; + if (!modelArch.sampleTags) return null; + if (!jobConfig.config.process[0].sample.samples) return null; + let sampleArr: any[] = []; + for (let i = 0; i < jobConfig.config.process[0].sample.samples.length; i++) { + const taggedPrompt = jobConfig.config.process[0].sample.samples[i].prompt; + const tagsObj = tagsToObj(taggedPrompt); + sampleArr.push(tagsObj); + } + return sampleArr; + }, [modelArch, jobConfig.config.process[0].sample.samples]); + + /* Splits the architecture sampleTags into groups of at most maxPerGroup entries, in declaration order; a tag flagged full always ends up alone in its own group. */ const modelArchTagSections: SampleTags[] | null = useMemo(() => { + if (!modelArch?.sampleTags) return null; + const maxPerGroup = 5; + let sections: SampleTags[] = []; + let subSection: SampleTags = {}; + for (const [tagKey, tag] of Object.entries(modelArch.sampleTags)) { + if ((tag.full && Object.keys(subSection).length > 0) || Object.keys(subSection).length >= maxPerGroup) { + // reset the sub section build if the next tag is full or max per group is reached + sections.push(subSection); + subSection = {}; + } + subSection[tagKey] = tag; + if (tag.full)
{ + // if the tag is full, push the section immediately and reset the sub section build + sections.push(subSection); + subSection = {}; + } + } + if (Object.keys(subSection).length > 0) { + sections.push(subSection); + } + return sections.length > 0 ? sections : null; + }, [modelArch]); + + /* Number of cards in the top row: four fixed cards plus one each for the multistage, quantization and slider cards when they apply. */ const numTopCards = useMemo(() => { + let count = 4; // job settings, model config, target config, save config + if (modelArch?.additionalSections?.includes('model.multistage')) { + count += 1; // add multistage card + } + if (!disableSections.includes('model.quantize')) { + count += 1; // add quantization card + } + if (!disableSections.includes('slider')) { + count += 1; // add slider card + } + return count; + }, [modelArch, disableSections]); + + /* Grid classes for the top row, with dedicated layouts for 5 and 6 cards. NOTE(review): numTopCards reaches 7 when all three optional cards apply, and that case keeps the 4-column default; confirm this is intended. */ let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6'; + + if (numTopCards == 5) { + topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6'; + } + if (numTopCards == 6) { + topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6'; + } + + /* Four training columns, plus one unless train.diff_output_preservation is in disableSections. */ const numTrainingCols = useMemo(() => { + let count = 4; + if (!disableSections.includes('train.diff_output_preservation')) { + count += 1; + } + return count; + }, [disableSections]); + + let trainingBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6'; + + if (numTrainingCols == 5) { + trainingBarClass = 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6'; + } + + /* Quantization option list: the flat quantizationOptions array when the architecture declares no accuracy recovery adapters, otherwise grouped options starting with a Standard group made of the first two entries. */ const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => { + const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0; + if (!hasARA) { + return quantizationOptions; + } + let newQuantizationOptions = [ + { + label: 'Standard', + options: [quantizationOptions[0], quantizationOptions[1]], + }, + ]; + + // add ARAs if they exist for the model + let ARAs: SelectOption[] = []; + if (modelArch.accuracyRecoveryAdapters) { + for
(const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) { + ARAs.push({ value, label }); + } + } + if (ARAs.length > 0) { + newQuantizationOptions.push({ + label: 'Accuracy Recovery Adapters', + options: ARAs, + }); + } + + /* Collects the options from index 2 onward, i.e. everything not already placed in the Standard group. */ let additionalQuantizationOptions: SelectOption[] = []; + // add the quantization options if they are not already included + for (let i = 2; i < quantizationOptions.length; i++) { + const option = quantizationOptions[i]; + additionalQuantizationOptions.push(option); + } + if (additionalQuantizationOptions.length > 0) { + newQuantizationOptions.push({ + label: 'Additional Quantization Options', + options: additionalQuantizationOptions, + }); + } + return newQuantizationOptions; + }, [modelArch]); + + /* The GPU select is shown whenever isMac() returns false. */ const showGPUSelect = !isMac(); + + /* Column counts and grid classes for the dataset and sample rows: a video model adds one sample column, an audio model removes one column from each row, and the class strings are switched for the 3 and 5 column cases. */ let numDatasetCols = 4; + let numSampleTopCols = 4; + let datasetStyleClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6'; + let sampleTopStyleClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6'; + if (isVideoModel) { + numSampleTopCols += 1; + } + if (isAudioModel) { + numDatasetCols -= 1; + numSampleTopCols -= 1; + } + if (numDatasetCols == 3) { + datasetStyleClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6'; + } + if (numSampleTopCols == 5) { + sampleTopStyleClass = 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6'; + } + if (numSampleTopCols == 3) { + sampleTopStyleClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6'; + } + return ( + <> +
+ {isLoading && ( +
+
+
+ Loading... +
+
+ )} +
+ + setJobConfig(value, 'config.name')} + placeholder="Enter training name" + disabled={runId !== null} + required + /> + {showGPUSelect && ( + setGpuIDs(value)} + options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} + /> + )} + {disableSections.includes('trigger_word') ? null : ( + { + if (value?.trim() === '') { + value = null; + } + setJobConfig(value, 'config.process[0].trigger_word'); + }} + placeholder="" + required + /> + )} + + + {/* Model Configuration Section */} + + { + handleModelArchChange(jobConfig.config.process[0].model.arch, value, jobConfig, setJobConfig); + }} + options={groupedModelOptions} + /> + { + if (value?.trim() === '') { + value = null; + } + setJobConfig(value, 'config.process[0].model.name_or_path'); + }} + placeholder="" + required + /> + {modelArch?.additionalSections?.includes('model.assistant_lora_path') && ( + { + if (value?.trim() === '') { + value = undefined; + } + setJobConfig(value, 'config.process[0].model.assistant_lora_path'); + }} + placeholder="" + /> + )} + {modelArch?.additionalSections?.includes('model.unconditional_lora_path') && ( + { + if (value?.trim() === '') { + value = undefined; + } + setJobConfig(value, 'config.process[0].model.unconditional_lora_path'); + }} + placeholder="" + /> + )} + {modelArch?.additionalSections?.includes('model.low_vram') && ( + + setJobConfig(value, 'config.process[0].model.low_vram')} + /> + + )} + {modelArch?.additionalSections?.includes('model.qie.match_target_res') && ( + setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')} + /> + )} + {modelArch?.additionalSections?.includes('model.layer_offloading') && !isMac() && ( + <> + + Layer Offloading {' '} + + } + checked={jobConfig.config.process[0].model.layer_offloading || false} + onChange={value => setJobConfig(value, 'config.process[0].model.layer_offloading')} + docKey="model.layer_offloading" + /> + {jobConfig.config.process[0].model.layer_offloading && ( +
+ + setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_transformer_percent') + } + min={0} + max={100} + step={1} + /> + + setJobConfig(value * 0.01, 'config.process[0].model.layer_offloading_text_encoder_percent') + } + min={0} + max={100} + step={1} + /> +
+ )} + + )} +
+ {disableSections.includes('model.quantize') ? null : ( + + { + if (value === '') { + setJobConfig(false, 'config.process[0].model.quantize'); + value = defaultQtype; + } else { + setJobConfig(true, 'config.process[0].model.quantize'); + } + setJobConfig(value, 'config.process[0].model.qtype'); + }} + options={transformerQuantizationOptions} + /> + {!disableSections.includes('model.quantize_te') && ( + { + if (value === '') { + setJobConfig(false, 'config.process[0].model.quantize_te'); + value = defaultQtype; + } else { + setJobConfig(true, 'config.process[0].model.quantize_te'); + } + setJobConfig(value, 'config.process[0].model.qtype_te'); + }} + options={quantizationOptions} + /> + )} + + <> + + { + setJobConfig(value, 'config.process[0].model.compile'); + if (value) { + for (const key in defaultCompileOptions) { + setJobConfig((defaultCompileOptions as any)[key], `config.process[0].model.${key}`); + } + } else { + for (const key in defaultCompileOptions) { + setJobConfig(undefined, `config.process[0].model.${key}`); + } + } + }} + /> + + )} + {modelArch?.additionalSections?.includes('model.multistage') && ( + + + setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')} + /> + setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')} + /> + + setJobConfig(value, 'config.process[0].train.switch_boundary_every')} + placeholder="eg. 
1" + docKey={'train.switch_boundary_every'} + min={1} + required + /> + + )} + + setJobConfig(value, 'config.process[0].network.type')} + options={[ + { value: 'lora', label: 'LoRA' }, + { value: 'lokr', label: 'LoKr' }, + ]} + /> + {jobConfig.config.process[0].network?.type == 'lokr' && ( + setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')} + options={[ + { value: '-1', label: 'Auto' }, + { value: '4', label: '4' }, + { value: '8', label: '8' }, + { value: '16', label: '16' }, + { value: '32', label: '32' }, + ]} + /> + )} + {jobConfig.config.process[0].network?.type == 'lora' && ( + <> + { + console.log('onChange', value); + setJobConfig(value, 'config.process[0].network.linear'); + setJobConfig(value, 'config.process[0].network.linear_alpha'); + }} + placeholder="eg. 16" + min={0} + max={1024} + required + /> + {disableSections.includes('network.conv') ? null : ( + { + console.log('onChange', value); + setJobConfig(value, 'config.process[0].network.conv'); + setJobConfig(value, 'config.process[0].network.conv_alpha'); + }} + placeholder="eg. 16" + min={0} + max={1024} + /> + )} + + )} + + {!disableSections.includes('slider') && ( + + setJobConfig(value, 'config.process[0].slider.target_class')} + placeholder="eg. person" + /> + setJobConfig(value, 'config.process[0].slider.positive_prompt')} + placeholder="eg. person who is happy" + /> + setJobConfig(value, 'config.process[0].slider.negative_prompt')} + placeholder="eg. person who is sad" + /> + setJobConfig(value, 'config.process[0].slider.anchor_class')} + placeholder="" + /> + + )} + + setJobConfig(value, 'config.process[0].save.dtype')} + options={[ + { value: 'bf16', label: 'BF16' }, + { value: 'fp16', label: 'FP16' }, + { value: 'fp32', label: 'FP32' }, + ]} + /> + setJobConfig(value, 'config.process[0].save.save_every')} + placeholder="eg. 250" + min={1} + required + /> + setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')} + placeholder="eg. 
4" + min={1} + required + /> + +
+
+ +
+
+ setJobConfig(value, 'config.process[0].train.batch_size')} + placeholder="eg. 4" + min={1} + required + /> + setJobConfig(value, 'config.process[0].train.gradient_accumulation')} + placeholder="eg. 1" + min={1} + required + /> + setJobConfig(value, 'config.process[0].train.steps')} + placeholder="eg. 2000" + min={1} + required + /> +
+
+ setJobConfig(value, 'config.process[0].train.optimizer')} + options={[ + { value: 'adafactor', label: 'Adafactor' }, + { value: 'adam', label: 'Adam' }, + { value: 'adamw', label: 'AdamW' }, + { value: 'adamw8bit', label: 'AdamW8Bit' }, + { value: 'automagic', label: 'Automagic' }, + { value: 'automagic2', label: 'Automagic v2' }, + { value: 'prodigyopt', label: 'Prodigy' }, + { value: 'prodigy8bit', label: 'Prodigy8Bit' }, + ]} + /> + setJobConfig(value, 'config.process[0].train.lr')} + placeholder="eg. 0.0001" + min={0} + required + /> + setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')} + placeholder="eg. 0.0001" + min={0} + required + /> +
+
+ {disableSections.includes('train.timestep_type') ? null : ( + setJobConfig(value, 'config.process[0].train.timestep_type')} + options={[ + { value: 'sigmoid', label: 'Sigmoid' }, + { value: 'linear', label: 'Linear' }, + { value: 'shift', label: 'Shift' }, + { value: 'weighted', label: 'Weighted' }, + ]} + /> + )} + setJobConfig(value, 'config.process[0].train.content_or_style')} + options={[ + { value: 'balanced', label: 'Balanced' }, + { value: 'content', label: 'High Noise' }, + { value: 'style', label: 'Low Noise' }, + ]} + /> + setJobConfig(value, 'config.process[0].train.loss_type')} + options={[ + { value: 'mse', label: 'Mean Squared Error' }, + { value: 'mae', label: 'Mean Absolute Error' }, + { value: 'wavelet', label: 'Wavelet' }, + { value: 'stepped', label: 'Stepped Recovery' }, + ]} + /> + {modelArch?.additionalSections?.includes('train.audio_loss_multiplier') && ( + setJobConfig(value, 'config.process[0].train.audio_loss_multiplier')} + placeholder="eg. 1.0" + docKey={'train.audio_loss_multiplier'} + min={0} + /> + )} +
+
+ + setJobConfig(value, 'config.process[0].train.ema_config.use_ema')} + /> + + {jobConfig.config.process[0].train.ema_config?.use_ema && ( + setJobConfig(value, 'config.process[0].train.ema_config.ema_decay')} + placeholder="eg. 0.99" + min={0} + /> + )} + + + {!disableSections.includes('train.unload_text_encoder') && ( + { + setJobConfig(value, 'config.process[0].train.unload_text_encoder'); + if (value) { + setJobConfig(false, 'config.process[0].train.cache_text_embeddings'); + } + }} + /> + )} + { + setJobConfig(value, 'config.process[0].train.cache_text_embeddings'); + if (value) { + setJobConfig(false, 'config.process[0].train.unload_text_encoder'); + } + }} + /> + +
+
+ {disableSections.includes('train.diff_output_preservation') || + disableSections.includes('train.blank_prompt_preservation') ? null : ( + + <> + + )} + {disableSections.includes('train.diff_output_preservation') ? null : ( + <> + { + setJobConfig(value, 'config.process[0].train.diff_output_preservation'); + if (value && jobConfig.config.process[0].train.blank_prompt_preservation) { + // only one can be enabled at a time + setJobConfig(false, 'config.process[0].train.blank_prompt_preservation'); + } + }} + /> + {jobConfig.config.process[0].train.diff_output_preservation && ( + <> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') + } + placeholder="eg. 1.0" + min={0} + /> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation_class') + } + placeholder="eg. woman" + /> + + )} + + )} + {disableSections.includes('train.blank_prompt_preservation') ? null : ( + <> + { + setJobConfig(value, 'config.process[0].train.blank_prompt_preservation'); + if (value && jobConfig.config.process[0].train.diff_output_preservation) { + // only one can be enabled at a time + setJobConfig(false, 'config.process[0].train.diff_output_preservation'); + } + }} + /> + {jobConfig.config.process[0].train.blank_prompt_preservation && ( + <> + + setJobConfig(value, 'config.process[0].train.blank_prompt_preservation_multiplier') + } + placeholder="eg. 1.0" + min={0} + /> + + )} + + )} +
+
+
+
+
+ +
+
+ { + let newValue = value == false ? undefined : value; + setJobConfig(newValue, 'config.process[0].train.do_differential_guidance'); + if (!newValue) { + setJobConfig(undefined, 'config.process[0].train.differential_guidance_scale'); + } else if ( + jobConfig.config.process[0].train.differential_guidance_scale === undefined || + jobConfig.config.process[0].train.differential_guidance_scale === null + ) { + // set default differential guidance scale to 3.0 + setJobConfig(3.0, 'config.process[0].train.differential_guidance_scale'); + } + }} + /> + {jobConfig.config.process[0].train.differential_guidance_scale && ( + <> + setJobConfig(value, 'config.process[0].train.differential_guidance_scale')} + placeholder="eg. 3.0" + min={0} + /> + + )} +
+
+
+
+
+ + <> + {jobConfig.config.process[0].datasets.map((dataset, i) => ( +
+
+ + +
+

Dataset {i + 1}

+
+
+ setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} + options={datasetOptions} + /> + {modelArch?.additionalSections?.includes('datasets.control_path') && ( + + setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + )} + {modelArch?.additionalSections?.includes('datasets.multi_control_paths') && ( + <> + + setJobConfig( + value == '' ? null : value, + `config.process[0].datasets[${i}].control_path_1`, + ) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + + setJobConfig( + value == '' ? null : value, + `config.process[0].datasets[${i}].control_path_2`, + ) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + + setJobConfig( + value == '' ? null : value, + `config.process[0].datasets[${i}].control_path_3`, + ) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + + )} + setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)} + placeholder="eg. 1.0" + /> + setJobConfig(value, `config.process[0].datasets[${i}].num_repeats`)} + placeholder="eg. 1" + docKey={'dataset.num_repeats'} + /> +
+
+ setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)} + placeholder="eg. A photo of a cat" + /> + setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)} + placeholder="eg. 0.05" + min={0} + required + /> + setJobConfig(value, `config.process[0].datasets[${i}].caption_ext`)} + options={[ + { value: 'txt', label: 'txt' }, + { value: 'json', label: 'json' }, + { value: 'caption', label: 'caption' }, + ]} + /> + + {modelArch?.additionalSections?.includes('datasets.num_frames') && !dataset.auto_frame_count && ( + setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)} + placeholder="eg. 41" + min={1} + required + /> + )} +
+
+ + + setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`) + } + /> + setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)} + /> + {modelArch?.additionalSections?.includes('datasets.auto_frame_count') && ( + setJobConfig(value, `config.process[0].datasets[${i}].auto_frame_count`)} + docKey="datasets.auto_frame_count" + /> + )} + {modelArch?.additionalSections?.includes('datasets.do_i2v') && ( + setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)} + docKey="datasets.do_i2v" + /> + )} + {modelArch?.additionalSections?.includes('datasets.do_audio') && ( + { + if (!value) { + setJobConfig(undefined, `config.process[0].datasets[${i}].do_audio`); + } else { + setJobConfig(value, `config.process[0].datasets[${i}].do_audio`); + } + }} + docKey="datasets.do_audio" + /> + )} + {modelArch?.additionalSections?.includes('datasets.audio_normalize') && ( + { + if (!value) { + setJobConfig(undefined, `config.process[0].datasets[${i}].audio_normalize`); + } else { + setJobConfig(value, `config.process[0].datasets[${i}].audio_normalize`); + } + }} + docKey="datasets.audio_normalize" + /> + )} + {modelArch?.additionalSections?.includes('datasets.audio_preserve_pitch') && ( + { + if (!value) { + setJobConfig(undefined, `config.process[0].datasets[${i}].audio_preserve_pitch`); + } else { + setJobConfig(value, `config.process[0].datasets[${i}].audio_preserve_pitch`); + } + }} + docKey="datasets.audio_preserve_pitch" + /> + )} + + {!isAudioModel && ( + + + Flip X + + } + checked={dataset.flip_x || false} + onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)} + /> + + Flip Y + + } + checked={dataset.flip_y || false} + onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)} + /> + + )} +
+ {!isAudioModel && ( +
+ +
+ {[ + [256, 512, 768, 1024], + [1280, 1328, 1536, 2048], + ].map(resGroup => ( +
+ {resGroup.map(res => ( + { + const resolutions = dataset.resolution.includes(res) + ? dataset.resolution.filter(r => r !== res) + : [...dataset.resolution, res]; + setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`); + }} + /> + ))} +
+ ))} +
+
+
+ )} +
+
+ ))} + + +
+
+
+ +
+
+ setJobConfig(value, 'config.process[0].sample.sample_every')} + placeholder="eg. 250" + min={1} + required + /> + setJobConfig(value, 'config.process[0].sample.sampler')} + options={[ + { value: 'flowmatch', label: 'FlowMatch' }, + { value: 'ddpm', label: 'DDPM' }, + ]} + /> + setJobConfig(value, 'config.process[0].sample.guidance_scale')} + placeholder="eg. 1.0" + className="pt-2" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.sample_steps')} + placeholder="eg. 1" + className="pt-2" + min={1} + required + /> +
+ + {!isAudioModel && ( +
+ setJobConfig(value, 'config.process[0].sample.width')} + placeholder="eg. 1024" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.height')} + placeholder="eg. 1024" + className="pt-2" + min={0} + required + /> + {isVideoModel && ( +
+ setJobConfig(value, 'config.process[0].sample.num_frames')} + placeholder="eg. 0" + className="pt-2" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.fps')} + placeholder="eg. 0" + className="pt-2" + min={0} + required + /> +
+ )} +
+ )} + +
+ setJobConfig(value, 'config.process[0].sample.seed')} + placeholder="eg. 0" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.walk_seed')} + /> +
+
+ +
+ { + setJobConfig(value, 'config.process[0].train.skip_first_sample'); + // cannot do both, so disable the other + if (value) { + setJobConfig(false, 'config.process[0].train.force_first_sample'); + } + }} + /> +
+
+ { + setJobConfig(value, 'config.process[0].train.force_first_sample'); + // cannot do both, so disable the other + if (value) { + setJobConfig(false, 'config.process[0].train.skip_first_sample'); + } + }} + /> +
+
+ { + setJobConfig(value, 'config.process[0].train.disable_sampling'); + // cannot do both, so disable the other + if (value) { + setJobConfig(false, 'config.process[0].train.force_first_sample'); + } + }} + /> +
+
+
+
+
+ + {modelArch?.additionalSections?.includes('ideogram_4_prompt') && ( + + )} +
+ {jobConfig.config.process[0].sample.samples.map((sample, i) => ( +
+
+
+
+
+ {modelArch?.sampleTags && taggedSampleArr && modelArchTagSections ? ( + <> + {modelArchTagSections.map((sampleTagSection, sti) => ( +
+ {Object.entries(sampleTagSection).map(([tagKey, tag]) => ( +
+ {tag.type === 'text' && ( + { + let taggedSample = { ...taggedSampleArr[i] }; + taggedSample[tagKey] = value; + setJobConfig( + objToTags(taggedSample), + `config.process[0].sample.samples[${i}].prompt`, + ); + }} + placeholder={`Enter ${tag.title.toLowerCase()}`} + /> + )} + {tag.type === 'multiline' && ( + { + let taggedSample = { ...taggedSampleArr[i] }; + taggedSample[tagKey] = value; + setJobConfig( + objToTags(taggedSample), + `config.process[0].sample.samples[${i}].prompt`, + ); + }} + placeholder={`Enter ${tag.title.toLowerCase()}`} + /> + )} + {tag.type === 'number' && ( + { + let taggedSample = { ...taggedSampleArr[i] }; + taggedSample[tagKey] = value; + setJobConfig( + objToTags(taggedSample), + `config.process[0].sample.samples[${i}].prompt`, + ); + }} + placeholder={`Enter ${tag.title.toLowerCase()}`} + /> + )} +
+ ))} +
+ ))} + + ) : ( + <> + {modelArch?.hasMultiLinePrompts ? ( + setJobConfig(value, `config.process[0].sample.samples[${i}].prompt`)} + placeholder="Enter prompt" + required + /> + ) : ( + setJobConfig(value, `config.process[0].sample.samples[${i}].prompt`)} + placeholder="Enter prompt" + required + /> + )} + + )} + + {modelArch?.additionalSections?.includes('ideogram_4_prompt') && ( +
+ +
+ )} + +
+ {!isAudioModel && ( + { + // remove any non-numeric characters + value = value.replace(/\D/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].width; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + setJobConfig(intValue, `config.process[0].sample.samples[${i}].width`); + } else { + console.warn('Invalid width value:', value); + } + } + }} + placeholder={`${jobConfig.config.process[0].sample.width} (default)`} + /> + )} + {!isAudioModel && ( + { + // remove any non-numeric characters + value = value.replace(/\D/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].height; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + setJobConfig(intValue, `config.process[0].sample.samples[${i}].height`); + } else { + console.warn('Invalid height value:', value); + } + } + }} + placeholder={`${jobConfig.config.process[0].sample.height} (default)`} + /> + )} + { + // remove any non-numeric characters + value = value.replace(/\D/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].seed; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + setJobConfig(intValue, 
`config.process[0].sample.samples[${i}].seed`); + } else { + console.warn('Invalid seed value:', value); + } + } + }} + placeholder={`${jobConfig.config.process[0].sample.walk_seed ? jobConfig.config.process[0].sample.seed + i : jobConfig.config.process[0].sample.seed} (default)`} + /> + { + // remove any non-numeric, - or . characters + value = value.replace(/[^0-9.-]/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].network_multiplier; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + // set it as a string + setJobConfig(value, `config.process[0].sample.samples[${i}].network_multiplier`); + return; + } + }} + placeholder={`1.0 (default)`} + /> +
+
+ {modelArch?.additionalSections?.includes('datasets.multi_control_paths') && ( + +
+ {['ctrl_img_1', 'ctrl_img_2', 'ctrl_img_3'].map((ctrlKey, ctrl_idx) => ( + { + if (!imagePath) { + let newSamples = objectCopy(jobConfig.config.process[0].sample.samples); + delete newSamples[i][ctrlKey as keyof typeof sample]; + setJobConfig(newSamples, 'config.process[0].sample.samples'); + } else { + setJobConfig(imagePath, `config.process[0].sample.samples[${i}].${ctrlKey}`); + } + }} + /> + ))} +
+
+ )} + {modelArch?.additionalSections?.includes('sample.ctrl_img') && ( + { + if (!imagePath) { + let newSamples = objectCopy(jobConfig.config.process[0].sample.samples); + delete newSamples[i].ctrl_img; + setJobConfig(newSamples, 'config.process[0].sample.samples'); + } else { + setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`); + } + }} + /> + )} +
+
+
+
+ +
+
+
+ ))} + +
+
+ + {status === 'success' &&

Training saved successfully!

} + {status === 'error' &&

Error saving training. Please try again.

} + + + + ); +} diff --git a/ai-toolkit/ui/src/app/jobs/new/jobConfig.ts b/ai-toolkit/ui/src/app/jobs/new/jobConfig.ts new file mode 100644 index 0000000000000000000000000000000000000000..ec01502832b1396be69912341272c807f42de600 --- /dev/null +++ b/ai-toolkit/ui/src/app/jobs/new/jobConfig.ts @@ -0,0 +1,168 @@ +'use client'; +import { isMac } from '@/helpers/basic'; +import { defaultSampleConfig } from '@/helpers/defaultSamples'; +import { JobConfig, SampleConfig, DatasetConfig, SliderConfig } from '@/types'; + +export const defaultDatasetConfig: DatasetConfig = { + folder_path: '/path/to/images/folder', + mask_path: null, + mask_min_value: 0.1, + default_caption: '', + caption_ext: 'txt', + caption_dropout_rate: 0.05, + cache_latents_to_disk: false, + is_reg: false, + network_weight: 1, + resolution: [512, 768, 1024], + controls: [], + shrink_video_to_frames: true, + num_frames: 1, + flip_x: false, + flip_y: false, + num_repeats: 1, +}; + +export const defaultSliderConfig: SliderConfig = { + guidance_strength: 3.0, + anchor_strength: 1.0, + positive_prompt: 'person who is happy', + negative_prompt: 'person who is sad', + target_class: 'person', + anchor_class: '', +}; + +export const defaultCompileOptions = { + block_compile: true, +}; + +export const defaultJobConfig: JobConfig = { + job: 'extension', + config: { + name: 'my_first_lora_v1', + process: [ + { + type: 'diffusion_trainer', + training_folder: 'output', + sqlite_db_path: './aitk_db.db', + device: 'cuda', + trigger_word: null, + performance_log_every: 10, + network: { + type: 'lora', + linear: 32, + linear_alpha: 32, + conv: 16, + conv_alpha: 16, + lokr_full_rank: true, + lokr_factor: -1, + network_kwargs: { + ignore_if_contains: [], + }, + }, + save: { + dtype: 'bf16', + save_every: 250, + max_step_saves_to_keep: 4, + save_format: 'diffusers', + push_to_hub: false, + }, + datasets: [defaultDatasetConfig], + train: { + batch_size: 1, + bypass_guidance_embedding: true, + steps: 3000, + 
/**
 * Upgrades a job config written by an older version of the UI to the current shape.
 *
 * The config is migrated in place and the same object is returned. Migrations applied:
 *  - `sample.prompts` (non-empty array) becomes `sample.samples` (array of `{ prompt }`),
 *    and the `prompts` key is removed.
 *  - process type `ui_trainer` is renamed to `diffusion_trainer`.
 *  - `model.auto_memory` is moved to `model.layer_offloading` and the old key is removed.
 *  - a default `logging` section is added when the key is absent.
 *  - `device` is set to `mps` when `isMac()` returns true.
 *
 * A config without a first process entry is returned untouched instead of throwing,
 * so partially filled or hand-edited configs can still be loaded.
 *
 * @param jobConfig - config to migrate; mutated in place.
 * @returns the same `jobConfig` instance that was passed in.
 */
export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
  // Every migration below targets the first process entry; resolve it once.
  const proc = jobConfig?.config?.process?.[0];
  if (!proc) {
    // Nothing to migrate (missing config / empty process list).
    return jobConfig;
  }

  // upgrade prompt strings to samples
  const sample = proc.sample;
  if (sample && Array.isArray(sample.prompts) && sample.prompts.length > 0) {
    sample.samples = sample.prompts.map(prompt => ({ prompt }));
    delete sample.prompts;
  }

  // upgrade job from ui_trainer to diffusion_trainer
  if (proc.type === 'ui_trainer') {
    proc.type = 'diffusion_trainer';
  }

  // auto_memory was renamed to layer_offloading
  const model = proc.model;
  if (model && 'auto_memory' in model) {
    model.layer_offloading = (model.auto_memory || false) as boolean;
    delete model.auto_memory;
  }

  // Configs saved before logging existed have no `logging` key at all.
  if (!('logging' in proc)) {
    // The declared type always has `logging`, so this branch narrows to `never`;
    // the suppression is needed to assign on data that predates the field.
    //@ts-ignore
    proc.logging = {
      log_every: 1,
      use_ui_logger: true,
    };
  }

  if (isMac()) {
    proc.device = 'mps';
  }

  return jobConfig;
};
+export interface ModelArch { + name: string; + label: string; + group: ModelGroup; + controls?: Control[]; + isVideoModel?: boolean; + hasMultiLinePrompts?: boolean; + defaults?: { [key: string]: any }; + disableSections?: DisableableSections[]; + additionalSections?: AdditionalSections[]; + accuracyRecoveryAdapters?: { [key: string]: string }; + sampleTags?: SampleTags; +} + +const defaultNameOrPath = ''; +const defaultLinearRank = 32 + +export const modelArchs: ModelArch[] = [ + { + name: 'flux', + label: 'FLUX.1', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'flux_kontext', + label: 'FLUX.1-Kontext-dev', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img'], + }, + { + name: 'flex1', + label: 'Flex.1', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 
'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.bypass_guidance_embedding': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'flex2', + label: 'Flex.2', + group: 'image', + controls: ['depth', 'line', 'pose', 'inpaint'], + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ostris/Flex.2-preview', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.model_kwargs': [ + { + invert_inpaint_mask_chance: 0.2, + inpaint_dropout: 0.5, + control_dropout: 0.5, + inpaint_random_chance: 0.2, + do_random_inpainting: true, + random_blur_mask: true, + random_dialate_mask: true, + }, + {}, + ], + 'config.process[0].train.bypass_guidance_embedding': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'chroma', + label: 'Chroma', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['lodestones/Chroma1-Base', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'zeta_chroma', + label: 'Zeta Chroma', + group: 'experimental', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['lodestones/Zeta-Chroma/zeta-chroma-base-x0-pixel-dino-distance.safetensors', 
defaultNameOrPath], + 'config.process[0].model.extras_name_or_path': ['Tongyi-MAI/Z-Image-Turbo', undefined], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'wan21:1b', + label: 'Wan 2.1 (1.3B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-1.3B-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].datasets[x].fps': [16, undefined], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.num_frames', 'model.low_vram', 'datasets.auto_frame_count'], + }, + { + name: 'wan21_i2v:14b480p', + label: 'Wan 2.1 I2V (14B-480P)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].datasets[x].fps': [16, undefined], + }, + disableSections: ['network.conv'], + 
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.auto_frame_count'], + }, + { + name: 'wan21_i2v:14b', + label: 'Wan 2.1 I2V (14B-720P)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].datasets[x].fps': [16, undefined], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.auto_frame_count'], + }, + { + name: 'wan21:14b', + label: 'Wan 2.1 (14B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].datasets[x].fps': [16, undefined], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.num_frames', 'model.low_vram', 'datasets.auto_frame_count'], + }, + { + name: 'wan22_14b:t2v', + label: 'Wan 2.2 (14B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 
'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].datasets[x].fps': [16, undefined], + 'config.process[0].model.model_kwargs': [ + { + train_high_noise: true, + train_low_noise: true, + }, + {}, + ], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage', 'model.layer_offloading', 'datasets.auto_frame_count'], + accuracyRecoveryAdapters: { + // '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors', + '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors', + }, + }, + { + name: 'wan22_14b_i2v', + label: 'Wan 2.2 I2V (14B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].datasets[x].fps': [16, undefined], + 
'config.process[0].model.model_kwargs': [ + { + train_high_noise: true, + train_low_noise: true, + }, + {}, + ], + }, + disableSections: ['network.conv'], + additionalSections: [ + 'sample.ctrl_img', + 'datasets.num_frames', + 'model.low_vram', + 'model.multistage', + 'model.layer_offloading', + 'datasets.auto_frame_count', + ], + accuracyRecoveryAdapters: { + '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors', + }, + }, + { + name: 'wan22_5b', + label: 'Wan 2.2 TI2V (5B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.2-TI2V-5B-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [121, 1], + 'config.process[0].sample.fps': [24, 1], + 'config.process[0].sample.width': [768, 1024], + 'config.process[0].sample.height': [768, 1024], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].datasets[x].do_i2v': [true, undefined], + 'config.process[0].datasets[x].fps': [24, undefined], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v', 'datasets.auto_frame_count'], + }, + { + name: 'lumina2', + label: 'Lumina2', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 
'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'qwen_image', + label: 'Qwen-Image', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + }, + disableSections: ['network.conv'], + additionalSections: ['model.low_vram', 'model.layer_offloading'], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors', + }, + }, + { + name: 'qwen_image:2512', + label: 'Qwen-Image-2512', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-2512', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + }, + disableSections: ['network.conv'], + additionalSections: ['model.low_vram', 'model.layer_offloading'], + // Training an ARA now, the other one will not work + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_2512_torchao_uint3.safetensors', + 
'4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/qwen_image_2512_torchao_uint4.safetensors', + }, + }, + { + name: 'qwen_image_edit', + label: 'Qwen-Image-Edit', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram', 'model.layer_offloading'], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors', + }, + }, + { + name: 'qwen_image_edit_plus', + label: 'Qwen-Image-Edit-2509', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit-2509', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].model.model_kwargs': [ + { + match_target_res: false, + }, + {}, + ], + }, + disableSections: ['network.conv', 'train.unload_text_encoder'], + 
additionalSections: [ + 'datasets.multi_control_paths', + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + 'model.qie.match_target_res', + ], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors', + }, + }, + { + name: 'qwen_image_edit_plus:2511', + label: 'Qwen-Image-Edit-2511', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit-2511', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].model.model_kwargs': [ + { + match_target_res: false, + }, + {}, + ], + }, + disableSections: ['network.conv', 'train.unload_text_encoder'], + additionalSections: [ + 'datasets.multi_control_paths', + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + 'model.qie.match_target_res', + ], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2511_torchao_uint3.safetensors', + }, + }, + { + name: 'hidream', + label: 'HiDream', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.lr': [0.0002, 0.0001], + 'config.process[0].train.timestep_type': ['shift', 'sigmoid'], + 'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []], + }, + disableSections: ['network.conv'], + additionalSections: ['model.low_vram'], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/hidream_i1_full_torchao_uint3.safetensors', + }, + }, + { + name: 'hidream_e1', + label: 'HiDream E1', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-E1-1', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.lr': [0.0001, 0.0001], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'], + }, + { + name: 'sdxl', + label: 'SDXL', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [false, false], + 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], + 'config.process[0].sample.guidance_scale': [6, 4], + }, + disableSections: ['model.quantize', 'train.timestep_type'], + }, + { + name: 'sd15', + label: 'SD 1.5', + 
group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath], + 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], + 'config.process[0].sample.width': [512, 1024], + 'config.process[0].sample.height': [512, 1024], + 'config.process[0].sample.guidance_scale': [6, 4], + }, + disableSections: ['model.quantize', 'train.timestep_type'], + }, + { + name: 'omnigen2', + label: 'OmniGen2', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [true, false], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img'], + }, + { + name: 'flux2', + label: 'FLUX.2', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-dev', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].model.model_kwargs': [ + { + match_target_res: false, + }, + {}, + ], + }, + disableSections: ['network.conv'], + 
additionalSections: [ + 'datasets.multi_control_paths', + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + 'model.qie.match_target_res', + ], + }, + { + name: 'zimage:turbo', + label: 'Z-Image Turbo (w/ Training Adapter)', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Tongyi-MAI/Z-Image-Turbo', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].model.assistant_lora_path': [ + 'ostris/zimage_turbo_training_adapter/zimage_turbo_training_adapter_v2.safetensors', + undefined, + ], + 'config.process[0].sample.guidance_scale': [1, 4], + 'config.process[0].sample.sample_steps': [8, 25], + }, + disableSections: ['network.conv'], + additionalSections: ['model.low_vram', 'model.layer_offloading', 'model.assistant_lora_path'], + }, + { + name: 'zimage', + label: 'Z-Image', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Tongyi-MAI/Z-Image', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].sample.sample_steps': [30, 25], + }, + disableSections: ['network.conv'], + additionalSections: ['model.low_vram', 'model.layer_offloading'], + }, + { + name: 'zimage:deturbo', + label: 'Z-Image De-Turbo (De-Distilled)', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ostris/Z-Image-De-Turbo', defaultNameOrPath], + 'config.process[0].model.extras_name_or_path': ['Tongyi-MAI/Z-Image-Turbo', undefined], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].sample.guidance_scale': [3, 4], + 'config.process[0].sample.sample_steps': [25, 25], + }, + disableSections: ['network.conv'], + additionalSections: ['model.low_vram', 'model.layer_offloading'], + }, + { + name: 'ltx2', + label: 'LTX-2', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Lightricks/LTX-2', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [121, 1], + 'config.process[0].sample.fps': [24, 1], + 'config.process[0].sample.width': [768, 1024], + 'config.process[0].sample.height': 
[768, 1024], + 'config.process[0].train.audio_loss_multiplier': [1.0, undefined], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].datasets[x].do_i2v': [false, undefined], + 'config.process[0].datasets[x].do_audio': [true, undefined], + 'config.process[0].datasets[x].fps': [24, undefined], + 'config.process[0].datasets[x].auto_frame_count': [false, undefined], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v', 'train.audio_loss_multiplier', 'datasets.auto_frame_count'], + }, + { + name: 'ltx2.3', + label: 'LTX-2.3', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Lightricks/LTX-2.3/ltx-2.3-22b-dev.safetensors', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [121, 1], + 'config.process[0].sample.fps': [24, 1], + 'config.process[0].sample.width': [768, 1024], + 'config.process[0].sample.height': [768, 1024], + 'config.process[0].train.audio_loss_multiplier': [1.0, undefined], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].datasets[x].cache_latents_to_disk': [true, false], + 'config.process[0].datasets[x].do_i2v': [false, undefined], + 'config.process[0].datasets[x].do_audio': [true, undefined], + 'config.process[0].datasets[x].fps': [24, undefined], + 'config.process[0].datasets[x].auto_frame_count': [false, undefined], + }, + disableSections: ['network.conv'], + 
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.layer_offloading', 'model.low_vram', 'datasets.do_audio', 'datasets.audio_normalize', 'datasets.audio_preserve_pitch', 'datasets.do_i2v', 'train.audio_loss_multiplier', 'datasets.auto_frame_count'], + }, + { + name: 'flux2_klein_4b', + label: 'FLUX.2-klein-base-4B', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-4B', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].model.model_kwargs': [ + { + match_target_res: false, + }, + {}, + ], + }, + disableSections: ['network.conv'], + additionalSections: [ + 'datasets.multi_control_paths', + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + 'model.qie.match_target_res', + ], + }, + { + name: 'ernie_image', + label: 'ERNIE-Image', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['baidu/ERNIE-Image', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + }, + disableSections: ['network.conv'], + additionalSections: [ + 'model.low_vram', + 'model.layer_offloading', + ], + }, + { + name: 'flux2_klein_9b', + label: 'FLUX.2-klein-base-9B', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-klein-base-9B', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].model.model_kwargs': [ + { + match_target_res: false, + }, + {}, + ], + }, + disableSections: ['network.conv'], + additionalSections: [ + 'datasets.multi_control_paths', + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + 'model.qie.match_target_res', + ], + }, + { + name: 'ace_step_15_xl', + label: 'ACE-Step 1.5 XL', + group: 'audio', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ostris/ace_step_1.5_ComfyUI_files/ace_step_1.5_xl_base_aio.safetensors', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].sample': 
[defaultAudioSampleConfig, defaultSampleConfig], + }, + sampleTags: { + "CAPTION": { + title: "Audio Prompt", + type: "text", + full: true, + }, + "LYRICS": { + title: "Lyrics", + type: "multiline", + full: true, + }, + "BPM": { + title: "BPM", + type: "number", + }, + "KEYSCALE": { + title: "Key Scale", + type: "text", + }, + "TIMESIGNATURE": { + title: "Time Signature", + type: "text", + }, + "DURATION": { + title: "Duration (sec)", + type: "number", + }, + "LANGUAGE": { + title: "Language", + type: "text", + }, + }, + disableSections: ['network.conv'], + additionalSections: [ + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + ], + }, + { + name: 'ace_step_15', + label: 'ACE-Step 1.5', + group: 'audio', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ostris/ace_step_1.5_ComfyUI_files/ace_step_1.5_base_aio.safetensors', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + 'config.process[0].sample': [defaultAudioSampleConfig, defaultSampleConfig], + }, + sampleTags: { + "CAPTION": { + title: "Audio Prompt", + type: "text", + full: true, + }, + "LYRICS": { + title: "Lyrics", + type: "multiline", + full: true, + }, + "BPM": { + title: "BPM", + type: "number", + }, + "KEYSCALE": { + title: "Key Scale", + type: "text", + }, + "TIMESIGNATURE": { + title: "Time Signature", + type: "text", + }, + "DURATION": { + title: "Duration (sec)", + type: "number", + }, + "LANGUAGE": { + title: "Language", + type: "text", + }, + }, + disableSections: ['network.conv'], + additionalSections: [ + 
'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + ], + }, + { + name: 'nucleus_image', + label: 'Nucleus-Image', + group: 'image', + defaults: { + 'config.process[0].model.name_or_path': ['NucleusAI/Nucleus-Image', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].network.network_kwargs.ignore_if_contains': [['img_mlp.experts', 'img_mlp.gate'], []], + 'config.process[0].network.linear': [128, defaultLinearRank], + 'config.process[0].network.linear_alpha': [128, defaultLinearRank], + }, + disableSections: ['network.conv'], + additionalSections: ['model.low_vram'], + }, + { + name: 'hidream_o1', + label: 'HiDream-O1', + group: 'image', + defaults: { + 'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-O1-Image', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [false, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].network.conv': [undefined, 16], + 'config.process[0].network.conv_alpha': [undefined, 16], + 'config.process[0].train.max_loss': [1.0, undefined], + 'config.process[0].network.network_kwargs.ignore_if_contains': [['lm_head', 'patch_embed', 'visual'], []], + 'config.process[0].network.transformer_only': [false, undefined], + 'config.process[0].sample.width': [2048, 1024], + 'config.process[0].sample.height': [2048, 1024], + 'config.process[0].model.model_kwargs': [ + { + noise_scale_inference: 8.0, + noise_scale: 8.0, + }, + {}, + ], + }, + disableSections: [ + 'network.conv', + 'model.quantize_te', + 'train.unload_text_encoder', + ], + additionalSections: [ + 'model.low_vram', + 'model.layer_offloading', + ], + }, + { + name: 'zimage_l2p', + label: 'Z-Image L2P (pixel space)', + group: 'image', + defaults: { + 
'config.process[0].model.name_or_path': ['zhen-nan/L2P/model-1k-merge.safetensors', defaultNameOrPath], + 'config.process[0].model.extras_name_or_path': ['Tongyi-MAI/Z-Image-Turbo', undefined], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].network.conv': [undefined, 16], + 'config.process[0].network.conv_alpha': [undefined, 16], + 'config.process[0].model.low_vram': [true, false], + }, + disableSections: [ + 'network.conv', + ], + additionalSections: [ + 'model.low_vram', + 'model.layer_offloading', + ], + }, + { + name: 'ideogram4', + label: 'Ideogram4', + group: 'experimental', + defaults: { + 'config.process[0].model.name_or_path': ['ideogram-ai/ideogram-4-fp8', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].network.conv': [undefined, 16], + 'config.process[0].network.conv_alpha': [undefined, 16], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample': [defaultIdeogramSamplesConfig, defaultSampleConfig], + 'config.process[0].model.unconditional_lora_path': [ + 'ostris/ideogram_4_unconditional_lora/ideogram_4_unconditional_lora_r16.safetensors', + undefined, + ], + }, + disableSections: [ + 'network.conv', + ], + additionalSections: [ + 'model.low_vram', + 'model.layer_offloading', + 'ideogram_4_prompt', + 'model.unconditional_lora_path', + ], + hasMultiLinePrompts: true, + }, + { + name: 'prx_pixel', + label: 'PRXPixel (pixel space)', + group: 'image', + defaults: { + 'config.process[0].model.name_or_path': ['Photoroom/prxpixel-t2i', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.timestep_type': ['linear', 
'sigmoid'], + 'config.process[0].network.conv': [undefined, 16], + 'config.process[0].network.conv_alpha': [undefined, 16], + 'config.process[0].model.low_vram': [true, false], + }, + disableSections: [ + 'network.conv', + ], + additionalSections: [ + 'model.low_vram', + 'model.layer_offloading', + ], + }, + { + name: 'krea2', + label: 'Krea 2 (raw)', + group: 'image', + defaults: { + 'config.process[0].model.name_or_path': ['krea/Krea-2-Raw', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].network.conv': [undefined, 16], + 'config.process[0].network.conv_alpha': [undefined, 16], + 'config.process[0].model.low_vram': [true, false], + }, + disableSections: [ + 'network.conv', + ], + additionalSections: [ + 'model.low_vram', + 'model.layer_offloading', + ], + }, + { + name: 'krea2:turbo', + label: 'Krea 2 Turbo (w/ Training Adapter)', + group: 'image', + defaults: { + 'config.process[0].model.name_or_path': ['krea/Krea-2-Turbo', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].network.conv': [undefined, 16], + 'config.process[0].network.conv_alpha': [undefined, 16], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].model.assistant_lora_path': [ + 'ostris/krea2_turbo_training_adapter/krea2_turbo_training_adapter_v1.safetensors', + undefined, + ], + 'config.process[0].sample.guidance_scale': [1, 4], + 'config.process[0].sample.sample_steps': [8, 25], + }, + disableSections: [ + 'network.conv', + ], + additionalSections: [ + 'model.low_vram', + 'model.layer_offloading', + 'model.assistant_lora_path' + ], + }, + { + name: 'boogu_image', + label: 'Boogu Image', + group: 'image', + defaults: { + 
'config.process[0].model.name_or_path': ['Boogu/Boogu-Image-0.1-Base', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].network.conv': [undefined, 16], + 'config.process[0].network.conv_alpha': [undefined, 16], + 'config.process[0].model.low_vram': [true, false], + }, + disableSections: [ + 'network.conv', + ], + additionalSections: [ + 'model.low_vram', + 'model.layer_offloading', + ], + }, + { + name: 'boogu_image_edit', + label: 'Boogu Image Edit', + group: 'instruction', + defaults: { + 'config.process[0].model.name_or_path': ['Boogu/Boogu-Image-0.1-Edit', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].network.conv': [undefined, 16], + 'config.process[0].network.conv_alpha': [undefined, 16], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].model.model_kwargs': [ + { + match_target_res: false, + }, + {}, + ], + }, + disableSections: [ + 'network.conv', 'train.unload_text_encoder', + ], + additionalSections: [ + 'datasets.multi_control_paths', + 'sample.multi_ctrl_imgs', + 'model.low_vram', + 'model.layer_offloading', + 'model.qie.match_target_res', + ], + }, +].sort((a, b) => { + // Sort by label, case-insensitive + return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }); +}) as any; + +export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => { + const group = acc.find(g => g.label === arch.group); + if (group) { + group.options.push({ value: arch.name, label: arch.label }); + } else { + acc.push({ + label: arch.group, + options: [{ value: arch.name, label: arch.label }], + }); + } + return 
acc; +}, [] as GroupedSelectOption[]); + +export const quantizationOptions: SelectOption[] = [ + { value: '', label: '- NONE -' }, + { value: 'qfloat8', label: 'float8 (default)' }, + { value: 'uint7', label: '7 bit' }, + { value: 'uint6', label: '6 bit' }, + { value: 'uint5', label: '5 bit' }, + { value: 'uint4', label: '4 bit' }, + { value: 'uint3', label: '3 bit' }, + { value: 'uint2', label: '2 bit' }, +]; + +export const defaultQtype = 'qfloat8'; + +interface JobTypeOption extends SelectOption { + disableSections?: DisableableSections[]; + processSections?: string[]; + onActivate?: (config: JobConfig) => JobConfig; + onDeactivate?: (config: JobConfig) => JobConfig; +} + +export const jobTypeOptions: JobTypeOption[] = [ + { + value: 'diffusion_trainer', + label: 'LoRA Trainer', + disableSections: ['slider'], + }, + { + value: 'concept_slider', + label: 'Concept Slider', + disableSections: ['trigger_word', 'train.diff_output_preservation'], + onActivate: (config: JobConfig) => { + // add default slider config + config.config.process[0].slider = { ...defaultSliderConfig }; + return config; + }, + onDeactivate: (config: JobConfig) => { + // remove slider config + delete config.config.process[0].slider; + return config; + }, + }, +]; diff --git a/ai-toolkit/ui/src/app/jobs/new/page.tsx b/ai-toolkit/ui/src/app/jobs/new/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..d41808fbb6afc1a8fbc15b384363e5d698092a11 --- /dev/null +++ b/ai-toolkit/ui/src/app/jobs/new/page.tsx @@ -0,0 +1,336 @@ +'use client'; + +import { useEffect, useRef, useState } from 'react'; +import { useSearchParams, useRouter } from 'next/navigation'; +import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig'; +import { jobTypeOptions } from './options'; +import { JobConfig } from '@/types'; +import { objectCopy } from '@/utils/basic'; +import { useNestedState, setNestedValue } from '@/utils/hooks'; +import { SelectInput } from 
'@/components/formInputs'; +import useSettings from '@/hooks/useSettings'; +import useGPUInfo from '@/hooks/useGPUInfo'; +import useDatasetList from '@/hooks/useDatasetList'; +import YAML from 'yaml'; +import path from 'path'; +import { TopBar, MainContent } from '@/components/layout'; +import { Button } from '@headlessui/react'; +import { FaChevronLeft } from 'react-icons/fa'; +import SimpleJob from './SimpleJob'; +import AdvancedConfigEditor from '@/components/AdvancedConfigEditor'; +import ErrorBoundary from '@/components/ErrorBoundary'; +import { apiClient } from '@/utils/api'; + +const isDev = process.env.NODE_ENV === 'development'; + +export default function TrainingForm() { + const router = useRouter(); + const searchParams = useSearchParams(); + const runId = searchParams.get('id'); + const cloneId = searchParams.get('cloneId'); + const [gpuIDs, setGpuIDs] = useState(null); + const { settings, isSettingsLoaded } = useSettings(); + const { gpuList, isGPUInfoLoaded } = useGPUInfo(); + const { datasets, status: datasetFetchStatus } = useDatasetList(); + const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]); + const [showAdvancedView, setShowAdvancedView] = useState(false); + + const [jobConfig, setJobConfig] = useNestedState(objectCopy(migrateJobConfig(defaultJobConfig))); + const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle'); + const fileInputRef = useRef(null); + + const handleImportConfig = () => { + fileInputRef.current?.click(); + }; + + const handleFileSelected = (e: React.ChangeEvent) => { + const file = e.target.files?.[0]; + if (!file) return; + + const reader = new FileReader(); + reader.onload = () => { + try { + const text = reader.result as string; + let parsed: any; + if (file.name.endsWith('.json') || file.name.endsWith('.jsonc')) { + parsed = JSON.parse(text.replace(/\/\/.*$/gm, '').replace(/\/\*[\s\S]*?\*\//g, '')); + } else { + parsed = YAML.parse(text); + } + + // Set 
required fields (same pattern as AdvancedJob.handleChange) + try { + parsed.config.process[0].sqlite_db_path = './aitk_db.db'; + parsed.config.process[0].training_folder = settings.TRAINING_FOLDER; + parsed.config.process[0].device = 'cuda'; + parsed.config.process[0].performance_log_every = 10; + } catch (err) { + console.warn('Could not set required fields on imported config:', err); + } + + migrateJobConfig(parsed); + setJobConfig(parsed); + } catch (err) { + console.error('Failed to parse config file:', err); + alert('Failed to parse config file. Please check the file format.'); + } + }; + reader.readAsText(file); + + // Reset so the same file can be re-imported + e.target.value = ''; + }; + + useEffect(() => { + if (!isSettingsLoaded) return; + if (datasetFetchStatus !== 'success') return; + + const datasetOptions = datasets.map(name => ({ value: path.join(settings.DATASETS_FOLDER, name), label: name })); + setDatasetOptions(datasetOptions); + + if (datasetOptions.length > 0) { + const defaultDatasetPath = defaultDatasetConfig.folder_path; + // Use functional updater so we check the *current* state, not a stale closure + setJobConfig((prev: JobConfig) => { + let updated = prev; + for (let i = 0; i < prev.config.process[0].datasets.length; i++) { + if (prev.config.process[0].datasets[i].folder_path === defaultDatasetPath) { + updated = setNestedValue(updated, datasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`); + } + } + return updated; + }); + } + }, [datasets, settings, isSettingsLoaded, datasetFetchStatus]); + + // clone existing job + useEffect(() => { + if (cloneId) { + apiClient + .get(`/api/jobs?id=${cloneId}`) + .then(res => res.data) + .then(data => { + console.log('Clone Training:', data); + setGpuIDs(data.gpu_ids); + const newJobConfig = migrateJobConfig(JSON.parse(data.job_config)); + newJobConfig.config.name = `${newJobConfig.config.name}_copy`; + setJobConfig(newJobConfig); + }) + .catch(error => console.error('Error fetching 
training:', error)); + } + }, [cloneId]); + + useEffect(() => { + if (runId) { + apiClient + .get(`/api/jobs?id=${runId}`) + .then(res => res.data) + .then(data => { + console.log('Training:', data); + setGpuIDs(data.gpu_ids); + setJobConfig(migrateJobConfig(JSON.parse(data.job_config))); + }) + .catch(error => console.error('Error fetching training:', error)); + } + }, [runId]); + + useEffect(() => { + if (isGPUInfoLoaded) { + if (gpuIDs === null && gpuList.length > 0) { + setGpuIDs(`${gpuList[0].index}`); + } + } + }, [gpuList, isGPUInfoLoaded]); + + useEffect(() => { + if (isSettingsLoaded) { + setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder'); + } + }, [settings, isSettingsLoaded]); + + const saveJob = async () => { + if (status === 'saving') return; + setStatus('saving'); + + apiClient + .post('/api/jobs', { + id: runId, + name: jobConfig.config.name, + gpu_ids: gpuIDs, + job_config: jobConfig, + }) + .then(res => { + setStatus('success'); + if (runId) { + router.push(`/jobs/${runId}`); + } else { + router.push(`/jobs/${res.data.id}`); + } + }) + .catch(error => { + if (error.response?.status === 409) { + alert('Training name already exists. Please choose a different name.'); + } else { + alert('Failed to save job. Please try again.'); + } + console.log('Error saving training:', error); + }) + .finally(() => + setTimeout(() => { + setStatus('idle'); + }, 2000), + ); + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + saveJob(); + }; + + return ( + <> + +
+ +
+
+

+ {runId ? 'Edit Training Job' : 'New Training Job'} +

+
+
+ {showAdvancedView && ( + <> +
+ setGpuIDs(value)} + options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} + /> +
+
+
+ +
+
+ + )} + {!showAdvancedView && ( + <> +
+ { + // undo current job type changes + const currentOption = jobTypeOptions.find( + option => option.value === jobConfig?.config.process[0].type, + ); + if (currentOption && currentOption.onDeactivate) { + setJobConfig(currentOption.onDeactivate(objectCopy(jobConfig))); + } + const option = jobTypeOptions.find(option => option.value === value); + if (option) { + if (option.onActivate) { + setJobConfig(option.onActivate(objectCopy(jobConfig))); + } + jobTypeOptions.forEach(opt => { + if (opt.value !== option.value && opt.onDeactivate) { + setJobConfig(opt.onDeactivate(objectCopy(jobConfig))); + } + }); + } + setJobConfig(value, 'config.process[0].type'); + }} + options={jobTypeOptions} + /> +
+
+ + )} + +
+ +
+
+ +
+
+ + + + {showAdvancedView ? ( +
+ { + try { + parsed.config.process[0].sqlite_db_path = './aitk_db.db'; + parsed.config.process[0].training_folder = settings.TRAINING_FOLDER; + parsed.config.process[0].device = 'cuda'; + parsed.config.process[0].performance_log_every = 10; + } catch (e) { + console.warn(e); + } + return migrateJobConfig(parsed); + }} + /> +
+ ) : ( + + + Advanced job detected. Please switch to advanced view to continue. +
+ } + > + + + +
+ + )} + + ); +} diff --git a/ai-toolkit/ui/src/app/jobs/new/utils.ts b/ai-toolkit/ui/src/app/jobs/new/utils.ts new file mode 100644 index 0000000000000000000000000000000000000000..98b2471eae45d80134ff0a7af2c835f65a341ee0 --- /dev/null +++ b/ai-toolkit/ui/src/app/jobs/new/utils.ts @@ -0,0 +1,150 @@ +import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; +import { modelArchs, ModelArch } from './options'; +import { objectCopy } from '@/utils/basic'; + +const expandDatasetDefaults = ( + defaults: { [key: string]: any }, + numDatasets: number, +): { [key: string]: any } => { + // expands the defaults for datasets[x] to datasets[0], datasets[1], etc. + const expandedDefaults: { [key: string]: any } = { ...defaults }; + for (const key in defaults) { + if (key.includes('datasets[x].')) { + for (let i = 0; i < numDatasets; i++) { + const datasetKey = key.replace('datasets[x].', `datasets[${i}].`); + const v = defaults[key]; + expandedDefaults[datasetKey] = Array.isArray(v) ? [...v] : objectCopy(v); + } + delete expandedDefaults[key]; + } + } + return expandedDefaults; +}; + +export const handleModelArchChange = ( + currentArchName: string, + newArchName: string, + jobConfig: JobConfig, + setJobConfig: (value: any, key: string) => void, +) => { + const currentArch = modelArchs.find(a => a.name === currentArchName); + if (!currentArch || currentArch.name === newArchName) { + return; + } + + // update the defaults when a model is selected + const newArch = modelArchs.find(model => model.name === newArchName); + + // update vram setting + if (!newArch?.additionalSections?.includes('model.low_vram')) { + setJobConfig(false, 'config.process[0].model.low_vram'); + } + + // handle layer offloading setting + if (!newArch?.additionalSections?.includes('model.layer_offloading')) { + if ('layer_offloading' in jobConfig.config.process[0].model) { + const newModel = objectCopy(jobConfig.config.process[0].model); + delete newModel.layer_offloading; + delete 
newModel.layer_offloading_text_encoder_percent; + delete newModel.layer_offloading_transformer_percent; + setJobConfig(newModel, 'config.process[0].model'); + } + } else { + // set to false if not set + if (!('layer_offloading' in jobConfig.config.process[0].model)) { + setJobConfig(false, 'config.process[0].model.layer_offloading'); + setJobConfig(1.0, 'config.process[0].model.layer_offloading_text_encoder_percent'); + setJobConfig(1.0, 'config.process[0].model.layer_offloading_transformer_percent'); + } + } + + const numDatasets = jobConfig.config.process[0].datasets.length; + + let currentDefaults = expandDatasetDefaults(currentArch.defaults || {}, numDatasets); + let newDefaults = expandDatasetDefaults(newArch?.defaults || {}, numDatasets); + + // set new model + setJobConfig(newArchName, 'config.process[0].model.arch'); + + // update datasets + const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false; + const hasMultiControlPaths = newArch?.additionalSections?.includes('datasets.multi_control_paths') || false; + const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false; + const hasAutoFrameCount = newArch?.additionalSections?.includes('datasets.auto_frame_count') || false; + const controls = newArch?.controls ?? 
[]; + const datasets = jobConfig.config.process[0].datasets.map(dataset => { + const newDataset = objectCopy(dataset); + newDataset.controls = controls; + if (hasMultiControlPaths) { + // make sure the config has the multi control paths + newDataset.control_path_1 = newDataset.control_path_1 || null; + newDataset.control_path_2 = newDataset.control_path_2 || null; + newDataset.control_path_3 = newDataset.control_path_3 || null; + // if we previously had a single control path and now + // we selected a multi control model + if (newDataset.control_path && newDataset.control_path !== '') { + // only set if not overwriting + if (!newDataset.control_path_1) { + newDataset.control_path_1 = newDataset.control_path; + } + } + delete newDataset.control_path; // remove single control path + } else if (hasControlPath) { + newDataset.control_path = newDataset.control_path || null; + if (newDataset.control_path_1 && newDataset.control_path_1 !== '') { + newDataset.control_path = newDataset.control_path_1; + } + if ('control_path_1' in newDataset) { + delete newDataset.control_path_1; + } + if ('control_path_2' in newDataset) { + delete newDataset.control_path_2; + } + if ('control_path_3' in newDataset) { + delete newDataset.control_path_3; + } + } else { + // does not have control images + if ('control_path' in newDataset) { + delete newDataset.control_path; + } + if ('control_path_1' in newDataset) { + delete newDataset.control_path_1; + } + if ('control_path_2' in newDataset) { + delete newDataset.control_path_2; + } + if ('control_path_3' in newDataset) { + delete newDataset.control_path_3; + } + } + if (!hasNumFrames) { + newDataset.num_frames = 1; // reset num_frames if not applicable + } + if (!hasAutoFrameCount) { + delete newDataset.auto_frame_count; + } + return newDataset; + }); + setJobConfig(datasets, 'config.process[0].datasets'); + + // update samples + const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false; + const samples = 
jobConfig.config.process[0].sample.samples.map(sample => { + const newSample = objectCopy(sample); + if (!hasSampleCtrlImg) { + delete newSample.ctrl_img; // remove ctrl_img if not applicable + } + return newSample; + }); + setJobConfig(samples, 'config.process[0].sample.samples'); + + // revert defaults from previous model + for (const key in currentDefaults) { + setJobConfig(currentDefaults[key][1], key); + } + + for (const key in newDefaults) { + setJobConfig(newDefaults[key][0], key); + } +}; diff --git a/ai-toolkit/ui/src/app/jobs/page.tsx b/ai-toolkit/ui/src/app/jobs/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..570427061e8bdd043b0a4e39bb60a92fa462ef25 --- /dev/null +++ b/ai-toolkit/ui/src/app/jobs/page.tsx @@ -0,0 +1,30 @@ +'use client'; + +import JobsTable from '@/components/JobsTable'; +import { TopBar, MainContent } from '@/components/layout'; +import Link from 'next/link'; + +export default function Dashboard() { + return ( + <> + +
+

Queue

+
+
+
+ + + New Job + New Training Job + +
+
+ + + + + ); +} diff --git a/ai-toolkit/ui/src/app/layout.tsx b/ai-toolkit/ui/src/app/layout.tsx new file mode 100644 index 0000000000000000000000000000000000000000..e4a91fa3d15572803ca2be683cbb1f2cb4b87569 --- /dev/null +++ b/ai-toolkit/ui/src/app/layout.tsx @@ -0,0 +1,73 @@ +import type { Metadata } from 'next'; +import { Inter } from 'next/font/google'; +import './globals.css'; +import Sidebar from '@/components/Sidebar'; +import { ThemeProvider } from '@/components/ThemeProvider'; +import ConfirmModal from '@/components/ConfirmModal'; +import { Suspense } from 'react'; +import AuthWrapper from '@/components/AuthWrapper'; +import DocModal from '@/components/DocModal'; +import os from 'os'; +import { CaptionDatasetModal } from '@/components/CaptionDatasetModal'; +import MergeLoRAsModal from '@/components/MergeLoRAsModal'; +import UpsamplePromptsModal from '@/components/UpsamplePromptsModal'; +import PromptBoxEditorModal from '@/components/PromptBoxEditorModal'; + +export const dynamic = 'force-dynamic'; + +const inter = Inter({ subsets: ['latin'] }); + +export const metadata: Metadata = { + title: 'Ostris - AI Toolkit', + description: 'A toolkit for building AI things.', +}; + +export const viewport = { + width: 'device-width', + initialScale: 1, + maximumScale: 5, +}; + +export default function RootLayout({ children }: { children: React.ReactNode }) { + // Check if the AI_TOOLKIT_AUTH environment variable is set + const authRequired = process.env.AI_TOOLKIT_AUTH ? true : false; + + const platform = os.platform(); + + return ( + + + +