Sasha (Spock) commited on
Commit
7a18c1b
·
1 Parent(s): c8cd022

materialize ai-toolkit (binaries + bak files removed); fix Dockerfile with PYTHONPATH

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +16 -10
  2. ai-toolkit/.gitignore +187 -0
  3. ai-toolkit/.gitmodules +0 -0
  4. ai-toolkit/FAQ.md +10 -0
  5. ai-toolkit/LICENSE +21 -0
  6. ai-toolkit/README.md +316 -0
  7. ai-toolkit/build_and_push_docker +29 -0
  8. ai-toolkit/build_and_push_docker_dev +21 -0
  9. ai-toolkit/config/examples/extract.example.yml +75 -0
  10. ai-toolkit/config/examples/generate.example.yaml +60 -0
  11. ai-toolkit/config/examples/mod_lora_scale.yaml +48 -0
  12. ai-toolkit/config/examples/modal/modal_train_lora_flux_24gb.yaml +96 -0
  13. ai-toolkit/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml +98 -0
  14. ai-toolkit/config/examples/train_flex_redux.yaml +112 -0
  15. ai-toolkit/config/examples/train_full_fine_tune_flex.yaml +107 -0
  16. ai-toolkit/config/examples/train_full_fine_tune_lumina.yaml +99 -0
  17. ai-toolkit/config/examples/train_lora_chroma_24gb.yaml +104 -0
  18. ai-toolkit/config/examples/train_lora_flex2_24gb.yaml +165 -0
  19. ai-toolkit/config/examples/train_lora_flex_24gb.yaml +101 -0
  20. ai-toolkit/config/examples/train_lora_flux_24gb.yaml +96 -0
  21. ai-toolkit/config/examples/train_lora_flux_kontext_24gb.yaml +106 -0
  22. ai-toolkit/config/examples/train_lora_flux_schnell_24gb.yaml +98 -0
  23. ai-toolkit/config/examples/train_lora_hidream_48.yaml +112 -0
  24. ai-toolkit/config/examples/train_lora_lumina.yaml +96 -0
  25. ai-toolkit/config/examples/train_lora_omnigen2_24gb.yaml +94 -0
  26. ai-toolkit/config/examples/train_lora_qwen_image_24gb.yaml +95 -0
  27. ai-toolkit/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml +105 -0
  28. ai-toolkit/config/examples/train_lora_qwen_image_edit_32gb.yaml +102 -0
  29. ai-toolkit/config/examples/train_lora_sd35_large_24gb.yaml +97 -0
  30. ai-toolkit/config/examples/train_lora_wan21_14b_24gb.yaml +101 -0
  31. ai-toolkit/config/examples/train_lora_wan21_1b_24gb.yaml +90 -0
  32. ai-toolkit/config/examples/train_lora_wan22_14b_24gb.yaml +111 -0
  33. ai-toolkit/config/examples/train_slider.example.yml +230 -0
  34. ai-toolkit/dgx_instructions.md +84 -0
  35. ai-toolkit/dgx_requirements.txt +13 -0
  36. ai-toolkit/docker-compose.yml +25 -0
  37. ai-toolkit/docker/Dockerfile +108 -0
  38. ai-toolkit/docker/start.sh +70 -0
  39. ai-toolkit/extensions/example/ExampleMergeModels.py +129 -0
  40. ai-toolkit/extensions/example/__init__.py +25 -0
  41. ai-toolkit/extensions/example/config/config.example.yaml +48 -0
  42. ai-toolkit/extensions_built_in/advanced_generator/Img2ImgGenerator.py +256 -0
  43. ai-toolkit/extensions_built_in/advanced_generator/PureLoraGenerator.py +102 -0
  44. ai-toolkit/extensions_built_in/advanced_generator/ReferenceGenerator.py +212 -0
  45. ai-toolkit/extensions_built_in/advanced_generator/__init__.py +59 -0
  46. ai-toolkit/extensions_built_in/advanced_generator/config/train.example.yaml +91 -0
  47. ai-toolkit/extensions_built_in/audio_models/__init__.py +7 -0
  48. ai-toolkit/extensions_built_in/audio_models/ace_step/__init__.py +1 -0
  49. ai-toolkit/extensions_built_in/audio_models/ace_step/ace_step_15_model.py +335 -0
  50. ai-toolkit/extensions_built_in/audio_models/ace_step/src/__init__.py +0 -0
Dockerfile CHANGED
@@ -1,22 +1,28 @@
1
-
2
  FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime
3
 
4
  WORKDIR /app
 
 
5
  RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
6
 
7
- # Install ai-toolkit
8
- RUN git clone https://github.com/ostris/ai-toolkit.git /app/ai-toolkit
9
- WORKDIR /app/ai-toolkit
10
- RUN git submodule update --init --recursive
11
- RUN pip install --no-cache-dir -e .
 
 
 
 
 
12
 
13
- # Install HF Hub
14
- RUN pip install --no-cache-dir huggingface_hub
15
 
16
  # Copy training files
17
  COPY . /app/
18
 
19
- # Pre-download FLUX model and assistant LoRA
20
- RUN python -c "from huggingface_hub import snapshot_download; snapshot_download('Niansuh/FLUX.1-schnell', cache_dir='/app/hf_cache'); snapshot_download('ostris/FLUX.1-schnell-training-adapter', cache_dir='/app/hf_cache')"
21
 
22
  CMD ["python", "/app/train_cloud.py"]
 
 
1
  FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime
2
 
3
  WORKDIR /app
4
+
5
+ # System deps
6
  RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
7
 
8
+ # Pre-baked ai-toolkit (numpy 2.5 / dctorch / torchdiffeq / torchsde / clip forks, 3.14-compatible)
9
+ # Copied from local checkout so HF doesn't have to clone the submodule / run pip install -e .
10
+ COPY ai-toolkit /app/ai-toolkit
11
+
12
+ # Make ai-toolkit importable. Upstream ai-toolkit ships without setup.py / pyproject.toml,
13
+ # so pip install -e . would fail. We add it to PYTHONPATH instead.
14
+ ENV PYTHONPATH=/app:/app/ai-toolkit
15
+ ENV PYTHONUNBUFFERED=1
16
+ ENV HF_HOME=/app/hf_cache
17
+ ENV TRANSFORMERS_CACHE=/app/hf_cache
18
 
19
+ # Install runtime deps
20
+ RUN pip install --no-cache-dir huggingface_hub hf_transfer
21
 
22
  # Copy training files
23
  COPY . /app/
24
 
25
+ # Pre-download FLUX base + training adapter at build time so they're in the image cache
26
+ RUN python -c "import os; os.environ['HF_HUB_ENABLE_HF_TRANSFER']='1'; from huggingface_hub import snapshot_download; snapshot_download('Niansuh/FLUX.1-schnell'); snapshot_download('ostris/FLUX.1-schnell-training-adapter')"
27
 
28
  CMD ["python", "/app/train_cloud.py"]
ai-toolkit/.gitignore ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ .python
126
+ .node
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ .idea/
163
+
164
+ /env.sh
165
+ /models
166
+ /datasets
167
+ /custom/*
168
+ !/custom/.gitkeep
169
+ /.tmp
170
+ /venv.bkp
171
+ /venv.*
172
+ /config/*
173
+ !/config/examples
174
+ !/config/_PUT_YOUR_CONFIGS_HERE).txt
175
+ /output/*
176
+ !/output/.gitkeep
177
+ /extensions/*
178
+ !/extensions/example
179
+ /temp
180
+ /wandb
181
+ .vscode/settings.json
182
+ .DS_Store
183
+ ._.DS_Store
184
+ aitk_db.db
185
+ /notes.md
186
+ /data
187
+ .claude
ai-toolkit/.gitmodules ADDED
File without changes
ai-toolkit/FAQ.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # FAQ
2
+
3
+ WIP. Will continue to add things as they are needed.
4
+
5
+ ## FLUX.1 Training
6
+
7
+ #### How much VRAM is required to train a lora on FLUX.1?
8
+
9
+ 24GB minimum is required.
10
+
ai-toolkit/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Ostris, LLC
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ai-toolkit/README.md ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ostris AI Toolkit
2
+
3
+ AI Toolkit is an easy to use all in one training suite for diffusion models. I try to support all the latest models on consumer grade hardware. Image and video models. It can be run as a GUI or CLI. It is designed to be easy to use but still have every feature imaginable. Free and open source.
4
+
5
+
6
+
7
+ ## Supported Models
8
+
9
+ ### Image
10
+ - [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) (FLUX.1)
11
+ - [black-forest-labs/FLUX.2-dev](https://huggingface.co/black-forest-labs/FLUX.2-dev) (FLUX.2)
12
+ - [black-forest-labs/FLUX.2-klein-base-4B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B) (FLUX.2-klein-base-4B)
13
+ - [black-forest-labs/FLUX.2-klein-base-9B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B) (FLUX.2-klein-base-9B)
14
+ - [ostris/Flex.1-alpha](https://huggingface.co/ostris/Flex.1-alpha) (Flex.1)
15
+ - [ostris/Flex.2-preview](https://huggingface.co/ostris/Flex.2-preview) (Flex.2)
16
+ - [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) (Chroma)
17
+ - [Alpha-VLLM/Lumina-Image-2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) (Lumina2)
18
+ - [Qwen/Qwen-Image](https://huggingface.co/Qwen/Qwen-Image) (Qwen-Image)
19
+ - [Qwen/Qwen-Image-2512](https://huggingface.co/Qwen/Qwen-Image-2512) (Qwen-Image-2512)
20
+ - [HiDream-ai/HiDream-I1-Full](https://huggingface.co/HiDream-ai/HiDream-I1-Full) (HiDream I1)
21
+ - [OmniGen2/OmniGen2](https://huggingface.co/OmniGen2/OmniGen2) (OmniGen2)
22
+ - [Tongyi-MAI/Z-Image-Turbo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) (Z-Image Turbo)
23
+ - [Tongyi-MAI/Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image) (Z-Image)
24
+ - [ostris/Z-Image-De-Turbo](https://huggingface.co/ostris/Z-Image-De-Turbo) (Z-Image De-Turbo)
25
+ - [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) (SDXL)
26
+ - [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) (SD 1.5)
27
+ - [baidu/ERNIE-Image](https://huggingface.co/baidu/ERNIE-Image) (ERNIE-Image)
28
+ - [NucleusAI/Nucleus-Image](https://huggingface.co/NucleusAI/Nucleus-Image) (Nucleus-Image)
29
+ - [HiDream-ai/HiDream-O1-Image](https://huggingface.co/HiDream-ai/HiDream-O1-Image) (HiDream O1)
30
+ - [Photoroom/prxpixel-t2i](https://huggingface.co/Photoroom/prxpixel-t2i) (PRXPixel)
31
+
32
+ ### Instruction / Edit
33
+ - [black-forest-labs/FLUX.1-Kontext-dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) (FLUX.1-Kontext-dev)
34
+ - [Qwen/Qwen-Image-Edit](https://huggingface.co/Qwen/Qwen-Image-Edit) (Qwen-Image-Edit)
35
+ - [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) (Qwen-Image-Edit-2509)
36
+ - [Qwen/Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) (Qwen-Image-Edit-2511)
37
+ - [HiDream-ai/HiDream-E1-1](https://huggingface.co/HiDream-ai/HiDream-E1-1) (HiDream E1)
38
+
39
+ ### Video
40
+ - [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) (Wan 2.1 1.3B)
41
+ - [Wan-AI/Wan2.1-I2V-14B-480P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers) (Wan 2.1 I2V 14B-480P)
42
+ - [Wan-AI/Wan2.1-I2V-14B-720P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers) (Wan 2.1 I2V 14B-720P)
43
+ - [Wan-AI/Wan2.1-T2V-14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) (Wan 2.1 14B)
44
+ - [Wan-AI/Wan2.2-T2V-A14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) (Wan 2.2 14B)
45
+ - [Wan-AI/Wan2.2-I2V-A14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) (Wan 2.2 I2V 14B)
46
+ - [Wan-AI/Wan2.2-TI2V-5B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers) (Wan 2.2 TI2V 5B)
47
+ - [Lightricks/LTX-2](https://huggingface.co/Lightricks/LTX-2) (LTX-2)
48
+ - [Lightricks/LTX-2.3](https://huggingface.co/Lightricks/LTX-2.3) (LTX-2.3)
49
+ - [krea/Krea-2-Raw](https://huggingface.co/krea/Krea-2-Raw) (Krea 2)
50
+
51
+ ### Audio
52
+ - [ACE-Step/Ace-Step1.5](https://huggingface.co/ACE-Step/Ace-Step1.5) (Ace Step 1.5)
53
+ - [ACE-Step/acestep-v15-xl-base](https://huggingface.co/ACE-Step/acestep-v15-xl-base) (Ace Step 1.5 XL)
54
+
55
+ ### Experimental
56
+ - [lodestones/Zeta-Chroma](https://huggingface.co/lodestones/Zeta-Chroma) (Zeta Chroma)
57
+ - [ideogram-ai/ideogram-4-fp8](https://huggingface.co/ideogram-ai/ideogram-4-fp8) (Ideogram 4 FP8)
58
+
59
+ ## Installation
60
+
61
+ Requirements:
62
+ - python >=3.10 (3.12 recommended)
63
+ - Nvidia GPU with enough ram to do what you need
64
+ - python venv
65
+ - git
66
+
67
+
68
+ Linux:
69
+ ```bash
70
+ git clone https://github.com/ostris/ai-toolkit.git
71
+ cd ai-toolkit
72
+ python3 -m venv venv
73
+ source venv/bin/activate
74
+ # install torch first
75
+ pip3 install --no-cache-dir torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
76
+ pip3 install -r requirements.txt
77
+ ```
78
+
79
+ For devices running **DGX OS** (including DGX Spark), follow [these](dgx_instructions.md) instructions.
80
+
81
+
82
+ Windows:
83
+
84
+ If you are having issues with Windows. I recommend using the easy install script at [https://github.com/Tavris1/AI-Toolkit-Easy-Install](https://github.com/Tavris1/AI-Toolkit-Easy-Install)
85
+
86
+ ```bash
87
+ git clone https://github.com/ostris/ai-toolkit.git
88
+ cd ai-toolkit
89
+ python -m venv venv
90
+ .\venv\Scripts\activate
91
+ pip install --no-cache-dir torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
92
+ pip install -r requirements.txt
93
+ ```
94
+
95
+ MacOS:
96
+
97
+ Experimental support for Silicon Macs is available. I do not have a Mac with enough RAM to fully test this
98
+ so please let me know if there are issues. There is a convience script to install and run on MacOS
99
+ locates at `./run_mac.zsh` that will install the dependencies locally and run the UI. To run this,
100
+ do the following:
101
+
102
+ ```bash
103
+ git clone https://github.com/ostris/ai-toolkit.git
104
+ cd ai-toolkit
105
+ chmod +x run_mac.zsh
106
+ ./run_mac.zsh
107
+ ```
108
+
109
+
110
+ # AI Toolkit UI
111
+
112
+ <img src="https://ostris.com/wp-content/uploads/2025/02/toolkit-ui.jpg" alt="AI Toolkit UI" width="100%">
113
+
114
+ The AI Toolkit UI is a web interface for the AI Toolkit. It allows you to easily start, stop, and monitor jobs. It also allows you to easily train models with a few clicks. It also allows you to set a token for the UI to prevent unauthorized access so it is mostly safe to run on an exposed server.
115
+
116
+ ## Running the UI
117
+
118
+ Requirements:
119
+ - Node.js > 20
120
+
121
+ The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs. The commands below
122
+ will install / update the UI and it's dependencies and start the UI.
123
+
124
+ ```bash
125
+ cd ui
126
+ npm run build_and_start
127
+ ```
128
+
129
+ You can now access the UI at `http://localhost:8675` or `http://<your-ip>:8675` if you are running it on a server.
130
+
131
+ ## Securing the UI
132
+
133
+ If you are hosting the UI on a cloud provider or any network that is not secure, I highly recommend securing it with an auth token.
134
+ You can do this by setting the environment variable `AI_TOOLKIT_AUTH` to super secure password. This token will be required to access
135
+ the UI. You can set this when starting the UI like so:
136
+
137
+ ```bash
138
+ # Linux
139
+ AI_TOOLKIT_AUTH=super_secure_password npm run build_and_start
140
+
141
+ # Windows
142
+ set AI_TOOLKIT_AUTH=super_secure_password && npm run build_and_start
143
+
144
+ # Windows Powershell
145
+ $env:AI_TOOLKIT_AUTH="super_secure_password"; npm run build_and_start
146
+ ```
147
+
148
+ ### Training
149
+ 1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml`
150
+ 2. Edit the file following the comments in the file
151
+ 3. Run the file like so `python run.py config/whatever_you_want.yml`
152
+
153
+ A folder with the name and the training folder from the config file will be created when you start. It will have all
154
+ checkpoints and images in it. You can stop the training at any time using ctrl+c and when you resume, it will pick back up
155
+ from the last checkpoint.
156
+
157
+ IMPORTANT. If you press crtl+c while it is saving, it will likely corrupt that checkpoint. So wait until it is done saving
158
+
159
+ ### Need help?
160
+
161
+ Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU)
162
+ and ask for help there. However, please refrain from PMing me directly with general question or support. Ask in the discord
163
+ and I will answer when I can.
164
+
165
+ ## Ostris Cloud
166
+
167
+ You can use many cloud providers to rent GPUs. If you want to help support this project in the largest way possible, please consider using [Ostris Cloud](https://cloud.ostris.com). Ostris Cloud is owned and operated by me, Ostris, and every dollar earned goes directly back into funding the development of this project.
168
+
169
+ <a href="https://cloud.ostris.com" target="_blank"><img src="https://cloud.ostris.com/api/og" alt="Ostris Cloud" style="max-width:100%;width:600px;height:auto;"></a>
170
+
171
+
172
+ ## Training in RunPod
173
+ If you would like to use Runpod, but have not signed up yet, please consider using [my Runpod affiliate link](https://runpod.io?ref=h0y9jyr2) to help support this project.
174
+
175
+
176
+ I maintain an official Runpod Pod template here which can be accessed [here](https://console.runpod.io/deploy?template=0fqzfjy6f3&ref=h0y9jyr2).
177
+
178
+ I have also created a short video showing how to get started using AI Toolkit with Runpod [here](https://youtu.be/HBNeS-F6Zz8).
179
+
180
+ ## Training in Modal
181
+
182
+ ### 1. Setup
183
+ #### ai-toolkit:
184
+ ```
185
+ git clone https://github.com/ostris/ai-toolkit.git
186
+ cd ai-toolkit
187
+ git submodule update --init --recursive
188
+ python -m venv venv
189
+ source venv/bin/activate
190
+ pip install torch
191
+ pip install -r requirements.txt
192
+ pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
193
+ ```
194
+ #### Modal:
195
+ - Run `pip install modal` to install the modal Python package.
196
+ - Run `modal setup` to authenticate (if this doesn’t work, try `python -m modal setup`).
197
+
198
+ #### Hugging Face:
199
+ - Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
200
+ - Run `huggingface-cli login` and paste your token.
201
+
202
+ ### 2. Upload your dataset
203
+ - Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files in `ai-toolkit`.
204
+
205
+ ### 3. Configs
206
+ - Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```.
207
+ - Edit the config following the comments in the file, **<ins>be careful and follow the example `/root/ai-toolkit` paths</ins>**.
208
+
209
+ ### 4. Edit run_modal.py
210
+ - Set your entire local `ai-toolkit` path at `code_mount = modal.Mount.from_local_dir` like:
211
+
212
+ ```
213
+ code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit")
214
+ ```
215
+ - Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and 2 hour timeout)_.
216
+
217
+ ### 5. Training
218
+ - Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`.
219
+ - You can monitor your training in your local terminal, or on [modal.com](https://modal.com/).
220
+ - Models, samples and optimizer will be stored in `Storage > flux-lora-models`.
221
+
222
+ ### 6. Saving the model
223
+ - Check contents of the volume by running `modal volume ls flux-lora-models`.
224
+ - Download the content by running `modal volume get flux-lora-models your-model-name`.
225
+ - Example: `modal volume get flux-lora-models my_first_flux_lora_v1`.
226
+
227
+ ### Screenshot from Modal
228
+
229
+ <img width="1728" alt="Modal Traning Screenshot" src="https://github.com/user-attachments/assets/7497eb38-0090-49d6-8ad9-9c8ea7b5388b">
230
+
231
+ ---
232
+
233
+ ## Dataset Preparation
234
+
235
+ Datasets generally need to be a folder containing images and associated text files. Currently, the only supported
236
+ formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images
237
+ but with a `.txt` extension. For example `image2.jpg` and `image2.txt`. The text file should contain only the caption.
238
+ You can add the word `[trigger]` in the caption file and if you have `trigger_word` in your config, it will be automatically
239
+ replaced.
240
+
241
+ Images are never upscaled but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**.
242
+ The loader will automatically resize them and can handle varying aspect ratios.
243
+
244
+
245
+ ## Training Specific Layers
246
+
247
+ To train specific layers with LoRA, you can use the `only_if_contains` network kwargs. For instance, if you want to train only the 2 layers
248
+ used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your
249
+ network kwargs like so:
250
+
251
+ ```yaml
252
+ network:
253
+ type: "lora"
254
+ linear: 128
255
+ linear_alpha: 128
256
+ network_kwargs:
257
+ only_if_contains:
258
+ - "transformer.single_transformer_blocks.7.proj_out"
259
+ - "transformer.single_transformer_blocks.20.proj_out"
260
+ ```
261
+
262
+ The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal
263
+ the suffix of the name of the layers you want to train. You can also use this method to only train specific groups of weights.
264
+ For instance to only train the `single_transformer` for FLUX.1, you can use the following:
265
+
266
+ ```yaml
267
+ network:
268
+ type: "lora"
269
+ linear: 128
270
+ linear_alpha: 128
271
+ network_kwargs:
272
+ only_if_contains:
273
+ - "transformer.single_transformer_blocks."
274
+ ```
275
+
276
+ You can also exclude layers by their names by using `ignore_if_contains` network kwarg. So to exclude all the single transformer blocks,
277
+
278
+
279
+ ```yaml
280
+ network:
281
+ type: "lora"
282
+ linear: 128
283
+ linear_alpha: 128
284
+ network_kwargs:
285
+ ignore_if_contains:
286
+ - "transformer.single_transformer_blocks."
287
+ ```
288
+
289
+ `ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
290
+ if will be ignored.
291
+
292
+ ## LoKr Training
293
+
294
+ To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). To train a LoKr model, you can adjust the network type in the config file like so:
295
+
296
+ ```yaml
297
+ network:
298
+ type: "lokr"
299
+ lokr_full_rank: true
300
+ lokr_factor: 8
301
+ ```
302
+
303
+ Everything else should work the same including layer targeting.
304
+
305
+
306
+ ## Support My Work
307
+
308
+ If you enjoy my projects or use them commercially, please consider sponsoring me. Every bit helps! 💖
309
+
310
+ <a href="https://ostris.com/sponsors" target="_blank"><img src="https://ostris.com/wp-content/uploads/2025/05/support-banner2.png" alt="Support my work" style="max-width:100%;height:auto;"></a>
311
+
312
+ ### Current Sponsors
313
+
314
+ All of these people / organizations are the ones who selflessly make this project possible. Thank you!!
315
+
316
+ <a href="https://ostris.com/sponsors"><img src="https://ostris.com/sponsors.svg" alt="Sponsors" style="width:100%;height:auto;"></a>
ai-toolkit/build_and_push_docker ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # Extract version from version.py
4
+ if [ -f "version.py" ]; then
5
+ VERSION=$(python3 -c "from version import VERSION; print(VERSION)")
6
+ echo "Building version: $VERSION"
7
+ else
8
+ echo "Error: version.py not found. Please create a version.py file with VERSION defined."
9
+ exit 1
10
+ fi
11
+
12
+ echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
13
+ echo "Building version: $VERSION and latest"
14
+ # wait 2 seconds
15
+ sleep 2
16
+
17
+ # Build the image with cache busting
18
+ docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
19
+
20
+ # Tag with version and latest
21
+ docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
22
+ docker tag aitoolkit:$VERSION ostris/aitoolkit:latest
23
+
24
+ # Push both tags
25
+ echo "Pushing images to Docker Hub..."
26
+ docker push ostris/aitoolkit:$VERSION
27
+ docker push ostris/aitoolkit:latest
28
+
29
+ echo "Successfully built and pushed ostris/aitoolkit:$VERSION and ostris/aitoolkit:latest"
ai-toolkit/build_and_push_docker_dev ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ VERSION=dev
4
+ GIT_COMMIT=dev
5
+
6
+ echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
7
+ echo "Building version: $VERSION"
8
+ # wait 2 seconds
9
+ sleep 2
10
+
11
+ # Build the image with cache busting
12
+ docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
13
+
14
+ # Tag with version and latest
15
+ docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
16
+
17
+ # Push both tags
18
+ echo "Pushing images to Docker Hub..."
19
+ docker push ostris/aitoolkit:$VERSION
20
+
21
+ echo "Successfully built and pushed ostris/aitoolkit:$VERSION"
ai-toolkit/config/examples/extract.example.yml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # this is in yaml format. You can use json if you prefer
3
+ # I like both but yaml is easier to read and write
4
+ # plus it has comments which is nice for documentation
5
+ job: extract # tells the runner what to do
6
+ config:
7
+ # the name will be used to create a folder in the output folder
8
+ # it will also replace any [name] token in the rest of this config
9
+ name: name_of_your_model
10
+ # can be hugging face model, a .ckpt, or a .safetensors
11
+ base_model: "/path/to/base/model.safetensors"
12
+ # can be hugging face model, a .ckpt, or a .safetensors
13
+ extract_model: "/path/to/model/to/extract/trained.safetensors"
14
+ # we will create folder here with name above so. This will create /path/to/output/folder/name_of_your_model
15
+ output_folder: "/path/to/output/folder"
16
+ is_v2: false
17
+ dtype: fp16 # saved dtype
18
+ device: cpu # cpu, cuda:0, etc
19
+
20
+ # processes can be chained like this to run multiple in a row
21
+ # they must all use same models above, but great for testing different
22
+ # sizes and typed of extractions. It is much faster as we already have the models loaded
23
+ process:
24
+ # process 1
25
+ - type: locon # locon or lora (locon is lycoris)
26
+ filename: "[name]_64_32.safetensors" # will be put in output folder
27
+ dtype: fp16
28
+ mode: fixed
29
+ linear: 64
30
+ conv: 32
31
+
32
+ # process 2
33
+ - type: locon
34
+ output_path: "/absolute/path/for/this/output.safetensors" # can be absolute
35
+ mode: ratio
36
+ linear: 0.2
37
+ conv: 0.2
38
+
39
+ # process 3
40
+ - type: locon
41
+ filename: "[name]_ratio_02.safetensors"
42
+ mode: quantile
43
+ linear: 0.5
44
+ conv: 0.5
45
+
46
+ # process 4
47
+ - type: lora # traditional lora extraction (lierla) with linear layers only
48
+ filename: "[name]_4.safetensors"
49
+ mode: fixed # fixed, ratio, quantile supported for lora as well
50
+ linear: 4 # lora dim or rank
51
+ # no conv for lora
52
+
53
+ # process 5
54
+ - type: lora
55
+ filename: "[name]_q05.safetensors"
56
+ mode: quantile
57
+ linear: 0.5
58
+
59
+ # you can put any information you want here, and it will be saved in the model
60
+ # the below is an example. I recommend doing trigger words at a minimum
61
+ # in the metadata. The software will include this plus some other information
62
+ meta:
63
+ name: "[name]" # [name] gets replaced with the name above
64
+ description: A short description of your model
65
+ trigger_words:
66
+ - put
67
+ - trigger
68
+ - words
69
+ - here
70
+ version: '0.1'
71
+ creator:
72
+ name: Your Name
73
+ email: your@email.com
74
+ website: https://yourwebsite.com
75
+ any: All meta data above is arbitrary, it can be whatever you want.
ai-toolkit/config/examples/generate.example.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+
3
+ job: generate # tells the runner what to do
4
+ config:
5
+ name: "generate" # this is not really used anywhere currently but required by runner
6
+ process:
7
+ # process 1
8
+ - type: to_folder # process images to a folder
9
+ output_folder: "output/gen"
10
+ device: cuda:0 # cpu, cuda:0, etc
11
+ generate:
12
+ # these are your defaults you can override most of them with flags
13
+ sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now
14
+ width: 1024
15
+ height: 1024
16
+ neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
17
+ seed: -1 # -1 is random
18
+ guidance_scale: 7
19
+ sample_steps: 20
20
+ ext: ".png" # .png, .jpg, .jpeg, .webp
21
+
22
+ # here ate the flags you can use for prompts. Always start with
23
+ # your prompt first then add these flags after. You can use as many
24
+ # like
25
+ # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
26
+ # we will try to support all sd-scripts flags where we can
27
+
28
+ # FROM SD-SCRIPTS
29
+ # --n Treat everything until the next option as a negative prompt.
30
+ # --w Specify the width of the generated image.
31
+ # --h Specify the height of the generated image.
32
+ # --d Specify the seed for the generated image.
33
+ # --l Specify the CFG scale for the generated image.
34
+ # --s Specify the number of steps during generation.
35
+
36
+ # OURS and some QOL additions
37
+ # --p2 Prompt for the second text encoder (SDXL only)
38
+ # --n2 Negative prompt for the second text encoder (SDXL only)
39
+ # --gr Specify the guidance rescale for the generated image (SDXL only)
40
+ # --seed Specify the seed for the generated image same as --d
41
+ # --cfg Specify the CFG scale for the generated image same as --l
42
+ # --steps Specify the number of steps during generation same as --s
43
+
44
+ prompt_file: false # if true a txt file will be created next to images with prompt strings used
45
+ # prompts can also be a path to a text file with one prompt per line
46
+ # prompts: "/path/to/prompts.txt"
47
+ prompts:
48
+ - "photo of batman"
49
+ - "photo of superman"
50
+ - "photo of spiderman"
51
+ - "photo of a superhero --n batman superman spiderman"
52
+
53
+ model:
54
+ # huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt
55
+ # name_or_path: "runwayml/stable-diffusion-v1-5"
56
+ name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
57
+ is_v2: false # for v2 models
58
+ is_v_pred: false # for v-prediction models (most v2 models)
59
+ is_xl: false # for SDXL models
60
+ dtype: bf16
ai-toolkit/config/examples/mod_lora_scale.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: mod
3
+ config:
4
+ name: name_of_your_model_v1
5
+ process:
6
+ - type: rescale_lora
7
+ # path to your current lora model
8
+ input_path: "/path/to/lora/lora.safetensors"
9
+ # output path for your new lora model, can be the same as input_path to replace
10
+ output_path: "/path/to/lora/output_lora_v1.safetensors"
11
+ # replaces meta with the meta below (plus minimum meta fields)
12
+ # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
13
+ replace_meta: true
14
+ # how to adjust, we can scale the up_down weights or the alpha
15
+ # up_down is the default and probably the best, they will both net the same outputs
16
+ # would only affect rare NaN cases and maybe merging with old merge tools
17
+ scale_target: 'up_down'
18
+ # precision to save, fp16 is the default and standard
19
+ save_dtype: fp16
20
+ # current_weight is the ideal weight you use as a multiplier when using the lora
21
+ # IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
22
+ # you can do negatives here too if you want to flip the lora
23
+ current_weight: 6.0
24
+ # target_weight is the ideal weight you use as a multiplier when using the lora
25
+ # instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
26
+ # we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
27
+ target_weight: 1.0
28
+
29
+ # base model for the lora
30
+ # this is just used to add meta so automatic111 knows which model it is for
31
+ # assume v1.5 if these are not set
32
+ is_xl: false
33
+ is_v2: false
34
+ meta:
35
+ # this is only used if you set replace_meta to true above
36
+ name: "[name]" # [name] gets replaced with the name above
37
+ description: A short description of your lora
38
+ trigger_words:
39
+ - put
40
+ - trigger
41
+ - words
42
+ - here
43
+ version: '0.1'
44
+ creator:
45
+ name: Your Name
46
+ email: your@email.com
47
+ website: https://yourwebsite.com
48
+ any: All meta data above is arbitrary, it can be whatever you want.
ai-toolkit/config/examples/modal/modal_train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need the on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and
65
+ # place it like "/root/ai-toolkit/FLUX.1-dev"
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
ai-toolkit/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need the on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and
65
+ # place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter"
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+ # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
ai-toolkit/config/examples/train_flex_redux.yaml ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flex_redux_finetune_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ adapter:
14
+ type: "redux"
15
+ # you can finetune an existing adapter or start from scratch. Set to null to start from scratch
16
+ name_or_path: '/local/path/to/redux_adapter_to_finetune.safetensors'
17
+ # name_or_path: null
18
+ # image_encoder_path: 'google/siglip-so400m-patch14-384' # Flux.1 redux adapter
19
+ image_encoder_path: 'google/siglip2-so400m-patch16-512' # Flex.1 512 redux adapter
20
+ # image_encoder_arch: 'siglip' # for Flux.1
21
+ image_encoder_arch: 'siglip2'
22
+ # You need a control input for each sample. Best to do squares for both images
23
+ test_img_path:
24
+ - "/path/to/x_01.jpg"
25
+ - "/path/to/x_02.jpg"
26
+ - "/path/to/x_03.jpg"
27
+ - "/path/to/x_04.jpg"
28
+ - "/path/to/x_05.jpg"
29
+ - "/path/to/x_06.jpg"
30
+ - "/path/to/x_07.jpg"
31
+ - "/path/to/x_08.jpg"
32
+ - "/path/to/x_09.jpg"
33
+ - "/path/to/x_10.jpg"
34
+ clip_layer: 'last_hidden_state'
35
+ train: true
36
+ save:
37
+ dtype: bf16 # precision to save
38
+ save_every: 250 # save every this many steps
39
+ max_step_saves_to_keep: 4
40
+ datasets:
41
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
42
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
43
+ # images will automatically be resized and bucketed into the resolution specified
44
+ # on windows, escape back slashes with another backslash so
45
+ # "C:\\path\\to\\images\\folder"
46
+ - folder_path: "/path/to/images/folder"
47
+ # clip_image_path is directory containting your control images. They must have filename as their train image. (extension does not matter)
48
+ # for normal redux, we are just recreating the same image, so you can use the same folder path above
49
+ clip_image_path: "/path/to/control/images/folder"
50
+ caption_ext: "txt"
51
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
52
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
53
+ train:
54
+ # this is what I used for the 24GB card, but feel free to adjust
55
+ # total batch size is 6 here
56
+ batch_size: 3
57
+ gradient_accumulation: 2
58
+
59
+ # captions are not needed for this training, we cache a blank proompt and rely on the vision encoder
60
+ unload_text_encoder: true
61
+
62
+ loss_type: "mse"
63
+ train_unet: true
64
+ train_text_encoder: false
65
+ steps: 4000000 # I set this very high and stop when I like the results
66
+ content_or_style: balanced # content, style, balanced
67
+ gradient_checkpointing: true
68
+ noise_scheduler: "flowmatch" # or "ddpm", "lms", "euler_a"
69
+ timestep_type: "flux_shift"
70
+ optimizer: "adamw8bit"
71
+ lr: 1e-4
72
+
73
+ # this is for Flex.1, comment this out for FLUX.1-dev
74
+ bypass_guidance_embedding: true
75
+
76
+ dtype: bf16
77
+ ema_config:
78
+ use_ema: true
79
+ ema_decay: 0.99
80
+ model:
81
+ name_or_path: "ostris/Flex.1-alpha"
82
+ is_flux: true
83
+ quantize: true
84
+ text_encoder_bits: 8
85
+ sample:
86
+ sampler: "flowmatch" # must match train.noise_scheduler
87
+ sample_every: 250 # sample every this many steps
88
+ width: 1024
89
+ height: 1024
90
+ # I leave half blank to test prompt and unprompted
91
+ prompts:
92
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
93
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
94
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
95
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
96
+ - "a bear building a log cabin in the snow covered mountains"
97
+ - ""
98
+ - ""
99
+ - ""
100
+ - ""
101
+ - ""
102
+ neg: ""
103
+ seed: 42
104
+ walk_seed: true
105
+ guidance_scale: 4
106
+ sample_steps: 25
107
+ network_multiplier: 1.0
108
+
109
+ # you can add any additional meta info here. [name] is replaced with config name at top
110
+ meta:
111
+ name: "[name]"
112
+ version: '1.0'
ai-toolkit/config/examples/train_full_fine_tune_flex.yaml ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # This configuration requires 48GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_flex_finetune_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ save:
18
+ dtype: bf16 # precision to save
19
+ save_every: 250 # save every this many steps
20
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
21
+ save_format: 'diffusers' # 'diffusers'
22
+ datasets:
23
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
24
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
25
+ # images will automatically be resized and bucketed into the resolution specified
26
+ # on windows, escape back slashes with another backslash so
27
+ # "C:\\path\\to\\images\\folder"
28
+ - folder_path: "/path/to/images/folder"
29
+ caption_ext: "txt"
30
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
31
+ shuffle_tokens: false # shuffle caption order, split by commas
32
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
33
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
34
+ train:
35
+ batch_size: 1
36
+ # IMPORTANT! For Flex, you must bypass the guidance embedder during training
37
+ bypass_guidance_embedding: true
38
+
39
+ # can be 'sigmoid', 'linear', or 'lognorm_blend'
40
+ timestep_type: 'sigmoid'
41
+
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flex
46
+ gradient_checkpointing: true # need the on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adafactor"
49
+ lr: 3e-5
50
+
51
+ # Paramiter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
52
+ # 0.1 is 10% of paramiters active at easc step. Only works with adafactor
53
+
54
+ # do_paramiter_swapping: true
55
+ # paramiter_swapping_factor: 0.9
56
+
57
+ # uncomment this to skip the pre training sample
58
+ # skip_first_sample: true
59
+ # uncomment to completely disable sampling
60
+ # disable_sampling: true
61
+
62
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
63
+ ema_config:
64
+ use_ema: true
65
+ ema_decay: 0.99
66
+
67
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
68
+ dtype: bf16
69
+ model:
70
+ # huggingface model name or path
71
+ name_or_path: "ostris/Flex.1-alpha"
72
+ is_flux: true # flex is flux architecture
73
+ # full finetuning quantized models is a crapshoot and results in subpar outputs
74
+ # quantize: true
75
+ # you can quantize just the T5 text encoder here to save vram
76
+ quantize_te: true
77
+ # only train the transformer blocks
78
+ only_if_contains:
79
+ - "transformer.transformer_blocks."
80
+ - "transformer.single_transformer_blocks."
81
+ sample:
82
+ sampler: "flowmatch" # must match train.noise_scheduler
83
+ sample_every: 250 # sample every this many steps
84
+ width: 1024
85
+ height: 1024
86
+ prompts:
87
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
88
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
89
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
90
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
91
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
92
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
93
+ - "a bear building a log cabin in the snow covered mountains"
94
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
95
+ - "hipster man with a beard, building a chair, in a wood shop"
96
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
97
+ - "a man holding a sign that says, 'this is a sign'"
98
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
99
+ neg: "" # not used on flex
100
+ seed: 42
101
+ walk_seed: true
102
+ guidance_scale: 4
103
+ sample_steps: 25
104
+ # you can add any additional meta info here. [name] is replaced with config name at top
105
+ meta:
106
+ name: "[name]"
107
+ version: '1.0'
ai-toolkit/config/examples/train_full_fine_tune_lumina.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # This configuration requires 24GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_lumina_finetune_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ save:
18
+ dtype: bf16 # precision to save
19
+ save_every: 250 # save every this many steps
20
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
21
+ save_format: 'diffusers' # 'diffusers'
22
+ datasets:
23
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
24
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
25
+ # images will automatically be resized and bucketed into the resolution specified
26
+ # on windows, escape back slashes with another backslash so
27
+ # "C:\\path\\to\\images\\folder"
28
+ - folder_path: "/path/to/images/folder"
29
+ caption_ext: "txt"
30
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
31
+ shuffle_tokens: false # shuffle caption order, split by commas
32
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
33
+ resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
34
+ train:
35
+ batch_size: 1
36
+
37
+ # can be 'sigmoid', 'linear', or 'lumina2_shift'
38
+ timestep_type: 'lumina2_shift'
39
+
40
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
41
+ gradient_accumulation: 1
42
+ train_unet: true
43
+ train_text_encoder: false # probably won't work with lumina2
44
+ gradient_checkpointing: true # need the on unless you have a ton of vram
45
+ noise_scheduler: "flowmatch" # for training only
46
+ optimizer: "adafactor"
47
+ lr: 3e-5
48
+
49
+ # Paramiter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
50
+ # 0.1 is 10% of paramiters active at easc step. Only works with adafactor
51
+
52
+ # do_paramiter_swapping: true
53
+ # paramiter_swapping_factor: 0.9
54
+
55
+ # uncomment this to skip the pre training sample
56
+ # skip_first_sample: true
57
+ # uncomment to completely disable sampling
58
+ # disable_sampling: true
59
+
60
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
61
+ # ema_config:
62
+ # use_ema: true
63
+ # ema_decay: 0.99
64
+
65
+ # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
66
+ dtype: bf16
67
+ model:
68
+ # huggingface model name or path
69
+ name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
70
+ is_lumina2: true # lumina2 architecture
71
+ # you can quantize just the Gemma2 text encoder here to save vram
72
+ quantize_te: true
73
+ sample:
74
+ sampler: "flowmatch" # must match train.noise_scheduler
75
+ sample_every: 250 # sample every this many steps
76
+ width: 1024
77
+ height: 1024
78
+ prompts:
79
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
80
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
81
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
82
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
83
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
84
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
85
+ - "a bear building a log cabin in the snow covered mountains"
86
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
87
+ - "hipster man with a beard, building a chair, in a wood shop"
88
+ - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
89
+ - "a man holding a sign that says, 'this is a sign'"
90
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
91
+ neg: ""
92
+ seed: 42
93
+ walk_seed: true
94
+ guidance_scale: 4.0
95
+ sample_steps: 25
96
+ # you can add any additional meta info here. [name] is replaced with config name at top
97
+ meta:
98
+ name: "[name]"
99
+ version: '1.0'
ai-toolkit/config/examples/train_lora_chroma_24gb.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_chroma_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # chroma enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with chroma
46
+ gradient_checkpointing: true # need the on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for chroma, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # Download the whichever model you prefer from the Chroma repo
66
+ # https://huggingface.co/lodestones/Chroma/tree/main
67
+ # point to it here.
68
+ # name_or_path: "/path/to/chroma/chroma-unlocked-vVERSION.safetensors"
69
+
70
+ # using lodestones/Chroma will automatically use the latest version
71
+ name_or_path: "lodestones/Chroma"
72
+
73
+ # # You can also select a version of Chroma like so
74
+ # name_or_path: "lodestones/Chroma/v28"
75
+
76
+ arch: "chroma"
77
+ quantize: true # run 8bit mixed precision
78
+ sample:
79
+ sampler: "flowmatch" # must match train.noise_scheduler
80
+ sample_every: 250 # sample every this many steps
81
+ width: 1024
82
+ height: 1024
83
+ prompts:
84
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
85
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
86
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
87
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
88
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
89
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
90
+ - "a bear building a log cabin in the snow covered mountains"
91
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
92
+ - "hipster man with a beard, building a chair, in a wood shop"
93
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
94
+ - "a man holding a sign that says, 'this is a sign'"
95
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
96
+ neg: "" # negative prompt, optional
97
+ seed: 42
98
+ walk_seed: true
99
+ guidance_scale: 4
100
+ sample_steps: 25
101
+ # you can add any additional meta info here. [name] is replaced with config name at top
102
+ meta:
103
+ name: "[name]"
104
+ version: '1.0'
ai-toolkit/config/examples/train_lora_flex2_24gb.yaml ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Note, Flex2 is a highly experimental WIP model. Finetuning a model with built in controls and inpainting has not
2
+ # been done before, so you will be experimenting with me on how to do it. This is my recommended setup, but this is highly
3
+ # subject to change as we learn more about how Flex2 works.
4
+
5
+ ---
6
+ job: extension
7
+ config:
8
+ # this name will be the folder and filename name
9
+ name: "my_first_flex2_lora_v1"
10
+ process:
11
+ - type: 'sd_trainer'
12
+ # root folder to save training sessions/samples/weights
13
+ training_folder: "output"
14
+ # uncomment to see performance stats in the terminal every N steps
15
+ # performance_log_every: 1000
16
+ device: cuda:0
17
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
18
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
19
+ # trigger_word: "p3r5on"
20
+ network:
21
+ type: "lora"
22
+ linear: 32
23
+ linear_alpha: 32
24
+ save:
25
+ dtype: float16 # precision to save
26
+ save_every: 250 # save every this many steps
27
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
28
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
29
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
30
+ # hf_repo_id: your-username/your-model-slug
31
+ # hf_private: true #whether the repo is private or public
32
+ datasets:
33
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
34
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
35
+ # images will automatically be resized and bucketed into the resolution specified
36
+ # on windows, escape back slashes with another backslash so
37
+ # "C:\\path\\to\\images\\folder"
38
+ - folder_path: "/path/to/images/folder"
39
+ # Flex2 is trained with controls and inpainting. If you want the model to truely understand how the
40
+ # controls function with your dataset, it is a good idea to keep doing controls during training.
41
+ # this will automatically generate the controls for you before training. The current script is not
42
+ # fully optimized so this could be rather slow for large datasets, but it caches them to disk so it
43
+ # only needs to be done once. If you want to skip this step, you can set the controls to [] and it will
44
+ controls:
45
+ - "depth"
46
+ - "line"
47
+ - "pose"
48
+ - "inpaint"
49
+
50
+ # you can make custom inpainting images as well. These images must be webp or png format with an alpha.
51
+ # just erase the part of the image you want to inpaint and save it as a webp or png. Again, erase your
52
+ # train target. So the person if training a person. The automatic controls above with inpaint will
53
+ # just run a background remover mask and erase the foreground, which works well for subjects.
54
+
55
+ # inpaint_path: "/my/impaint/images"
56
+
57
+ # you can also specify existing control image pairs. It can handle multiple groups and will randomly
58
+ # select one for each step.
59
+
60
+ # control_path:
61
+ # - "/my/custom/control/images"
62
+ # - "/my/custom/control/images2"
63
+
64
+ caption_ext: "txt"
65
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
66
+ resolution: [ 512, 768, 1024 ] # flex2 enjoys multiple resolutions
67
+ train:
68
+ batch_size: 1
69
+ # IMPORTANT! For Flex2, you must bypass the guidance embedder during training
70
+ bypass_guidance_embedding: true
71
+
72
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
73
+ gradient_accumulation: 1
74
+ train_unet: true
75
+ train_text_encoder: false # probably won't work with flex2
76
+ gradient_checkpointing: true # need the on unless you have a ton of vram
77
+ noise_scheduler: "flowmatch" # for training only
78
+ # shift works well for training fast and learning composition and style.
79
+ # for just subject, you may want to change this to sigmoid
80
+ timestep_type: 'shift' # 'linear', 'sigmoid', 'shift'
81
+ optimizer: "adamw8bit"
82
+ lr: 1e-4
83
+
84
+ optimizer_params:
85
+ weight_decay: 1e-5
86
+ # uncomment this to skip the pre training sample
87
+ # skip_first_sample: true
88
+ # uncomment to completely disable sampling
89
+ # disable_sampling: true
90
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
91
+ # linear_timesteps: true
92
+
93
+ # ema will smooth out learning, but could slow it down. Defaults off
94
+ ema_config:
95
+ use_ema: false
96
+ ema_decay: 0.99
97
+
98
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
99
+ dtype: bf16
100
+ model:
101
+ # huggingface model name or path
102
+ name_or_path: "ostris/Flex.2-preview"
103
+ arch: "flex2"
104
+ quantize: true # run 8bit mixed precision
105
+ quantize_te: true
106
+
107
+ # you can pass special training infor for controls to the model here
108
+ # percentages are decimal based so 0.0 is 0% and 1.0 is 100% of the time.
109
+ model_kwargs:
110
+ # inverts the inpainting mask, good to learn outpainting as well, recommended 0.0 for characters
111
+ invert_inpaint_mask_chance: 0.5
112
+ # this will do a normal t2i training step without inpaint when dropped out. REcommended if you want
113
+ # your lora to be able to inference with and without inpainting.
114
+ inpaint_dropout: 0.5
115
+ # randomly drops out the control image. Dropout recvommended if your want it to work without controls as well.
116
+ control_dropout: 0.5
117
+ # does a random inpaint blob. Usually a good idea to keep. Without it, the model will learn to always 100%
118
+ # fill the inpaint area with your subject. This is not always a good thing.
119
+ inpaint_random_chance: 0.5
120
+ # generates random inpaint blobs if you did not provide an inpaint image for your dataset. Inpaint breaks down fast
121
+ # if you are not training with it. Controls are a little more robust and can be left out,
122
+ # but when in doubt, always leave this on
123
+ do_random_inpainting: false
124
+ # does random blurring of the inpaint mask. Helps prevent weird edge artifacts for real workd inpainting. Leave on.
125
+ random_blur_mask: true
126
+ # applies a small amount of random dialition and restriction to the inpaint mask. Helps with edge artifacts.
127
+ # Leave on.
128
+ random_dialate_mask: true
129
+ sample:
130
+ sampler: "flowmatch" # must match train.noise_scheduler
131
+ sample_every: 250 # sample every this many steps
132
+ width: 1024
133
+ height: 1024
134
+ prompts:
135
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
136
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
137
+
138
+ # you can use a single inpaint or single control image on your samples.
139
+ # for controls, the ctrl_idx is 1, the images can be any name and image format.
140
+ # use either a pose/line/depth image or whatever you are training with. An example is
141
+ # - "photo of [trigger] --ctrl_idx 1 --ctrl_img /path/to/control/image.jpg"
142
+
143
+ # for an inpainting image, it must be png/webp. Erase the part of the image you want to inpaint
144
+ # IMPORTANT! the inpaint images must be ctrl_idx 0 and have .inpaint.{ext} in the name for this to work right.
145
+ # - "photo of [trigger] --ctrl_idx 0 --ctrl_img /path/to/inpaint/image.inpaint.png"
146
+
147
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
148
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
149
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
150
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
151
+ - "a bear building a log cabin in the snow covered mountains"
152
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
153
+ - "hipster man with a beard, building a chair, in a wood shop"
154
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
155
+ - "a man holding a sign that says, 'this is a sign'"
156
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
157
+ neg: "" # not used on flex2
158
+ seed: 42
159
+ walk_seed: true
160
+ guidance_scale: 4
161
+ sample_steps: 25
162
+ # you can add any additional meta info here. [name] is replaced with config name at top
163
+ meta:
164
+ name: "[name]"
165
+ version: '1.0'
ai-toolkit/config/examples/train_lora_flex_24gb.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flex_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ # IMPORTANT! For Flex, you must bypass the guidance embedder during training
43
+ bypass_guidance_embedding: true
44
+
45
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
46
+ gradient_accumulation: 1
47
+ train_unet: true
48
+ train_text_encoder: false # probably won't work with flex
49
+ gradient_checkpointing: true # need the on unless you have a ton of vram
50
+ noise_scheduler: "flowmatch" # for training only
51
+ optimizer: "adamw8bit"
52
+ lr: 1e-4
53
+ # uncomment this to skip the pre training sample
54
+ # skip_first_sample: true
55
+ # uncomment to completely disable sampling
56
+ # disable_sampling: true
57
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
58
+ # linear_timesteps: true
59
+
60
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
61
+ ema_config:
62
+ use_ema: true
63
+ ema_decay: 0.99
64
+
65
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
66
+ dtype: bf16
67
+ model:
68
+ # huggingface model name or path
69
+ name_or_path: "ostris/Flex.1-alpha"
70
+ is_flux: true
71
+ quantize: true # run 8bit mixed precision
72
+ quantize_kwargs:
73
+ exclude:
74
+ - "*time_text_embed*" # exclude the time text embedder from quantization
75
+ sample:
76
+ sampler: "flowmatch" # must match train.noise_scheduler
77
+ sample_every: 250 # sample every this many steps
78
+ width: 1024
79
+ height: 1024
80
+ prompts:
81
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
82
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
83
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
84
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
85
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
86
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
87
+ - "a bear building a log cabin in the snow covered mountains"
88
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
89
+ - "hipster man with a beard, building a chair, in a wood shop"
90
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
91
+ - "a man holding a sign that says, 'this is a sign'"
92
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
93
+ neg: "" # not used on flex
94
+ seed: 42
95
+ walk_seed: true
96
+ guidance_scale: 4
97
+ sample_steps: 25
98
+ # you can add any additional meta info here. [name] is replaced with config name at top
99
+ meta:
100
+ name: "[name]"
101
+ version: '1.0'
ai-toolkit/config/examples/train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+ gradient_checkpointing: true # need the on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
ai-toolkit/config/examples/train_lora_flux_kontext_24gb.yaml ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_kontext_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ # control path is the input images for kontext for a paired dataset. These are the source images you want to change.
36
+ # You can comment this out and only use normal images if you don't have a paired dataset.
37
+ # Control images need to match the filenames on the folder path but in
38
+ # a different folder. These do not need captions.
39
+ control_path: "/path/to/control/folder"
40
+ caption_ext: "txt"
41
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
42
+ shuffle_tokens: false # shuffle caption order, split by commas
43
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
44
+ # Kontext runs images in at 2x the latent size. It may OOM at 1024 resolution with 24GB vram.
45
+ resolution: [ 512, 768 ] # flux enjoys multiple resolutions
46
+ # resolution: [ 512, 768, 1024 ]
47
+ train:
48
+ batch_size: 1
49
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
50
+ gradient_accumulation_steps: 1
51
+ train_unet: true
52
+ train_text_encoder: false # probably won't work with flux
53
+ gradient_checkpointing: true # need the on unless you have a ton of vram
54
+ noise_scheduler: "flowmatch" # for training only
55
+ optimizer: "adamw8bit"
56
+ lr: 1e-4
57
+ timestep_type: "weighted" # sigmoid, linear, or weighted.
58
+ # uncomment this to skip the pre training sample
59
+ # skip_first_sample: true
60
+ # uncomment to completely disable sampling
61
+ # disable_sampling: true
62
+
63
+ # ema will smooth out learning, but could slow it down.
64
+
65
+ # ema_config:
66
+ # use_ema: true
67
+ # ema_decay: 0.99
68
+
69
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
70
+ dtype: bf16
71
+ model:
72
+ # huggingface model name or path. This model is gated.
73
+ # visit https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev to accept the terms and conditions
74
+ # and then you can use this model.
75
+ name_or_path: "black-forest-labs/FLUX.1-Kontext-dev"
76
+ arch: "flux_kontext"
77
+ quantize: true # run 8bit mixed precision
78
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
79
+ sample:
80
+ sampler: "flowmatch" # must match train.noise_scheduler
81
+ sample_every: 250 # sample every this many steps
82
+ width: 1024
83
+ height: 1024
84
+ prompts:
85
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
86
+ # the --ctrl_img path is the one loaded to apply the kontext editing to
87
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
88
+ - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
89
+ - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
90
+ - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
91
+ - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
92
+ - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
93
+ - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
94
+ - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
95
+ - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
96
+ - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
97
+ - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
98
+ neg: "" # not used on flux
99
+ seed: 42
100
+ walk_seed: true
101
+ guidance_scale: 4
102
+ sample_steps: 20
103
+ # you can add any additional meta info here. [name] is replaced with config name at top
104
+ meta:
105
+ name: "[name]"
106
+ version: '1.0'
ai-toolkit/config/examples/train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+ gradient_checkpointing: true # need the on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+ # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
ai-toolkit/config/examples/train_lora_hidream_48.yaml ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
2
+ # It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
3
+ # I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
4
+ # HiDream has a mixture of experts that may take special training considerations that I do not
5
+ # have implemented properly. The current implementation seems to work well for LoRA training, but
6
+ # may not be effective for longer training runs. The implementation could change in future updates
7
+ # so your results may vary when this happens.
8
+
9
+ ---
10
+ job: extension
11
+ config:
12
+ # this name will be the folder and filename name
13
+ name: "my_first_hidream_lora_v1"
14
+ process:
15
+ - type: 'sd_trainer'
16
+ # root folder to save training sessions/samples/weights
17
+ training_folder: "output"
18
+ # uncomment to see performance stats in the terminal every N steps
19
+ # performance_log_every: 1000
20
+ device: cuda:0
21
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
22
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
23
+ # trigger_word: "p3r5on"
24
+ network:
25
+ type: "lora"
26
+ linear: 32
27
+ linear_alpha: 32
28
+ network_kwargs:
29
+ # it is probably best to ignore the mixture of experts since only 2 are active each block. It works activating it, but I wouldnt.
30
+ # proper training of it is not fully implemented
31
+ ignore_if_contains:
32
+ - "ff_i.experts"
33
+ - "ff_i.gate"
34
+ save:
35
+ dtype: bfloat16 # precision to save
36
+ save_every: 250 # save every this many steps
37
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
38
+ datasets:
39
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
40
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
41
+ # images will automatically be resized and bucketed into the resolution specified
42
+ # on windows, escape back slashes with another backslash so
43
+ # "C:\\path\\to\\images\\folder"
44
+ - folder_path: "/path/to/images/folder"
45
+ caption_ext: "txt"
46
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
47
+ resolution: [ 512, 768, 1024 ] # hidream enjoys multiple resolutions
48
+ train:
49
+ batch_size: 1
50
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
51
+ gradient_accumulation_steps: 1
52
+ train_unet: true
53
+ train_text_encoder: false # wont work with hidream
54
+ gradient_checkpointing: true # need the on unless you have a ton of vram
55
+ noise_scheduler: "flowmatch" # for training only
56
+ timestep_type: shift # sigmoid, shift, linear
57
+ optimizer: "adamw8bit"
58
+ lr: 2e-4
59
+ # uncomment this to skip the pre training sample
60
+ # skip_first_sample: true
61
+ # uncomment to completely disable sampling
62
+ # disable_sampling: true
63
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
64
+ # linear_timesteps: true
65
+
66
+ # ema will smooth out learning, but could slow it down. Defaults off
67
+ ema_config:
68
+ use_ema: false
69
+ ema_decay: 0.99
70
+
71
+ # will probably need this if gpu supports it for hidream, other dtypes may not work correctly
72
+ dtype: bf16
73
+ model:
74
+ # the transformer will get grabbed from this hf repo
75
+ # warning ONLY train on Full. The dev and fast models are distilled and will break
76
+ name_or_path: "HiDream-ai/HiDream-I1-Full"
77
+ # the extras will be grabbed from this hf repo. (text encoder, vae)
78
+ extras_name_or_path: "HiDream-ai/HiDream-I1-Full"
79
+ arch: "hidream"
80
+ # both need to be quantized to train on 48GB currently
81
+ quantize: true
82
+ quantize_te: true
83
+ model_kwargs:
84
+ # llama is a gated model, It defaults to unsloth version, but you can set the llama path here
85
+ llama_model_path: "unsloth/Meta-Llama-3.1-8B-Instruct"
86
+ sample:
87
+ sampler: "flowmatch" # must match train.noise_scheduler
88
+ sample_every: 250 # sample every this many steps
89
+ width: 1024
90
+ height: 1024
91
+ prompts:
92
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
93
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
94
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
95
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
96
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
97
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
98
+ - "a bear building a log cabin in the snow covered mountains"
99
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
100
+ - "hipster man with a beard, building a chair, in a wood shop"
101
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
102
+ - "a man holding a sign that says, 'this is a sign'"
103
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
104
+ neg: ""
105
+ seed: 42
106
+ walk_seed: true
107
+ guidance_scale: 4
108
+ sample_steps: 25
109
+ # you can add any additional meta info here. [name] is replaced with config name at top
110
+ meta:
111
+ name: "[name]"
112
+ version: '1.0'
ai-toolkit/config/examples/train_lora_lumina.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # This configuration requires 20GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_lumina_lora_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: bf16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
25
+ save_format: 'diffusers' # 'diffusers'
26
+ datasets:
27
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
28
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
29
+ # images will automatically be resized and bucketed into the resolution specified
30
+ # on windows, escape back slashes with another backslash so
31
+ # "C:\\path\\to\\images\\folder"
32
+ - folder_path: "/path/to/images/folder"
33
+ caption_ext: "txt"
34
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
35
+ shuffle_tokens: false # shuffle caption order, split by commas
36
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
37
+ resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
38
+ train:
39
+ batch_size: 1
40
+
41
+ # can be 'sigmoid', 'linear', or 'lumina2_shift'
42
+ timestep_type: 'lumina2_shift'
43
+
44
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
45
+ gradient_accumulation: 1
46
+ train_unet: true
47
+ train_text_encoder: false # probably won't work with lumina2
48
+ gradient_checkpointing: true # need the on unless you have a ton of vram
49
+ noise_scheduler: "flowmatch" # for training only
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
67
+ is_lumina2: true # lumina2 architecture
68
+ # you can quantize just the Gemma2 text encoder here to save vram
69
+ quantize_te: true
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: ""
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4.0
92
+ sample_steps: 25
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
ai-toolkit/config/examples/train_lora_omnigen2_24gb.yaml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_omnigen2_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # omnigen2 should work with multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with omnigen2
46
+ gradient_checkpointing: true # need the on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ timestep_type: 'sigmoid' # sigmoid, linear, shift
51
+ # uncomment this to skip the pre training sample
52
+ # skip_first_sample: true
53
+ # uncomment to completely disable sampling
54
+ # disable_sampling: true
55
+
56
+ # ema will smooth out learning, but could slow it down.
57
+ # ema_config:
58
+ # use_ema: true
59
+ # ema_decay: 0.99
60
+
61
+ # will probably need this if gpu supports it for omnigen2, other dtypes may not work correctly
62
+ dtype: bf16
63
+ model:
64
+ name_or_path: "OmniGen2/OmniGen2
65
+ arch: "omnigen2"
66
+ quantize_te: true # quantize_only te
67
+ # quantize: true # quantize transformer
68
+ sample:
69
+ sampler: "flowmatch" # must match train.noise_scheduler
70
+ sample_every: 250 # sample every this many steps
71
+ width: 1024
72
+ height: 1024
73
+ prompts:
74
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
75
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
76
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
77
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
78
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
79
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
80
+ - "a bear building a log cabin in the snow covered mountains"
81
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
82
+ - "hipster man with a beard, building a chair, in a wood shop"
83
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
84
+ - "a man holding a sign that says, 'this is a sign'"
85
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
86
+ neg: "" # negative prompt, optional
87
+ seed: 42
88
+ walk_seed: true
89
+ guidance_scale: 4
90
+ sample_steps: 25
91
+ # you can add any additional meta info here. [name] is replaced with config name at top
92
+ meta:
93
+ name: "[name]"
94
+ version: '1.0'
ai-toolkit/config/examples/train_lora_qwen_image_24gb.yaml ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_qwen_image_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # Trigger words will not work when caching text embeddings
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: float16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
25
+ datasets:
26
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
27
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
28
+ # images will automatically be resized and bucketed into the resolution specified
29
+ # on windows, escape back slashes with another backslash so
30
+ # "C:\\path\\to\\images\\folder"
31
+ - folder_path: "/path/to/images/folder"
32
+ caption_ext: "txt"
33
+ # default_caption: "a person" # if caching text embeddings, if you dont have captions, this will get cached
34
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
35
+ shuffle_tokens: false # shuffle caption order, split by commas
36
+ cache_latents_to_disk: true # leave this true unless you have a large dataset
37
+ # if you OOM, 1024 may be too much, but should work
38
+ resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions
39
+ train:
40
+ batch_size: 1
41
+ # caching text embeddings is required for 24GB
42
+ cache_text_embeddings: true
43
+
44
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
45
+ gradient_accumulation: 1
46
+ train_unet: true
47
+ train_text_encoder: false # probably won't work with qwen image
48
+ gradient_checkpointing: true # need the on unless you have a ton of vram
49
+ noise_scheduler: "flowmatch" # for training only
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+ dtype: bf16
57
+ model:
58
+ # huggingface model name or path
59
+ name_or_path: "Qwen/Qwen-Image"
60
+ arch: "qwen_image"
61
+ quantize: true
62
+ # qtype_te: "qfloat8" Default float8 qquantization
63
+ # to use the ARA use the | pipe to point to hf path, or a local path if you have one.
64
+ # 3bit is required for 24GB
65
+ qtype: "uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors"
66
+ quantize_te: true
67
+ qtype_te: "qfloat8"
68
+ low_vram: true
69
+ sample:
70
+ sampler: "flowmatch" # must match train.noise_scheduler
71
+ sample_every: 250 # sample every this many steps
72
+ width: 1024
73
+ height: 1024
74
+ prompts:
75
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
76
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
77
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
78
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
79
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
80
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
81
+ - "a bear building a log cabin in the snow covered mountains"
82
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
83
+ - "hipster man with a beard, building a chair, in a wood shop"
84
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
85
+ - "a man holding a sign that says, 'this is a sign'"
86
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
87
+ neg: ""
88
+ seed: 42
89
+ walk_seed: true
90
+ guidance_scale: 3
91
+ sample_steps: 25
92
+ # you can add any additional meta info here. [name] is replaced with config name at top
93
+ meta:
94
+ name: "[name]"
95
+ version: '1.0'
ai-toolkit/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_qwen_image_edit_2509_lora_v1"
6
+ process:
7
+ - type: 'diffusion_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ network:
14
+ type: "lora"
15
+ linear: 16
16
+ linear_alpha: 16
17
+ save:
18
+ dtype: float16 # precision to save
19
+ save_every: 250 # save every this many steps
20
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
21
+ datasets:
22
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
23
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
24
+ # images will automatically be resized and bucketed into the resolution specified
25
+ # on windows, escape back slashes with another backslash so
26
+ # "C:\\path\\to\\images\\folder"
27
+ - folder_path: "/path/to/images/folder"
28
+ # can do up to 3 control image folders, file names must match target file names, but aspect/size can be different
29
+ control_path:
30
+ - "/path/to/control/images/folder1"
31
+ - "/path/to/control/images/folder2"
32
+ - "/path/to/control/images/folder3"
33
+ caption_ext: "txt"
34
+ # default_caption: "a person" # if caching text embeddings, if you don't have captions, this will get cached
35
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
36
+ resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions
37
+ # a trigger word that can be cached with the text embeddings
38
+ # trigger_word: "optional trigger word"
39
+ train:
40
+ batch_size: 1
41
+ # caching text embeddings is required for 32GB
42
+ cache_text_embeddings: true
43
+ # unload_text_encoder: true
44
+
45
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
46
+ gradient_accumulation: 1
47
+ timestep_type: "weighted"
48
+ train_unet: true
49
+ train_text_encoder: false # probably won't work with qwen image
50
+ gradient_checkpointing: true # need the on unless you have a ton of vram
51
+ noise_scheduler: "flowmatch" # for training only
52
+ optimizer: "adamw8bit"
53
+ lr: 1e-4
54
+ # uncomment this to skip the pre training sample
55
+ # skip_first_sample: true
56
+ # uncomment to completely disable sampling
57
+ # disable_sampling: true
58
+ dtype: bf16
59
+ model:
60
+ # huggingface model name or path
61
+ name_or_path: "Qwen/Qwen-Image-Edit-2509"
62
+ arch: "qwen_image_edit_plus"
63
+ quantize: true
64
+ # to use the ARA use the | pipe to point to hf path, or a local path if you have one.
65
+ # 3bit is required for 32GB
66
+ qtype: "uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors"
67
+ quantize_te: true
68
+ qtype_te: "qfloat8"
69
+ low_vram: true
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ # you can provide up to 3 control images here
76
+ samples:
77
+ - prompt: "Do whatever with Image1 and Image2"
78
+ ctrl_img_1: "/path/to/image1.png"
79
+ ctrl_img_2: "/path/to/image2.png"
80
+ # ctrl_img_3: "/path/to/image3.png"
81
+ - prompt: "Do whatever with Image1 and Image2"
82
+ ctrl_img_1: "/path/to/image1.png"
83
+ ctrl_img_2: "/path/to/image2.png"
84
+ # ctrl_img_3: "/path/to/image3.png"
85
+ - prompt: "Do whatever with Image1 and Image2"
86
+ ctrl_img_1: "/path/to/image1.png"
87
+ ctrl_img_2: "/path/to/image2.png"
88
+ # ctrl_img_3: "/path/to/image3.png"
89
+ - prompt: "Do whatever with Image1 and Image2"
90
+ ctrl_img_1: "/path/to/image1.png"
91
+ ctrl_img_2: "/path/to/image2.png"
92
+ # ctrl_img_3: "/path/to/image3.png"
93
+ - prompt: "Do whatever with Image1 and Image2"
94
+ ctrl_img_1: "/path/to/image1.png"
95
+ ctrl_img_2: "/path/to/image2.png"
96
+ # ctrl_img_3: "/path/to/image3.png"
97
+ neg: ""
98
+ seed: 42
99
+ walk_seed: true
100
+ guidance_scale: 3
101
+ sample_steps: 25
102
+ # you can add any additional meta info here. [name] is replaced with config name at top
103
+ meta:
104
+ name: "[name]"
105
+ version: '1.0'
ai-toolkit/config/examples/train_lora_qwen_image_edit_32gb.yaml ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_qwen_image_edit_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # Trigger words will not work when caching text embeddings
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: float16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
25
+ datasets:
26
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
27
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
28
+ # images will automatically be resized and bucketed into the resolution specified
29
+ # on windows, escape back slashes with another backslash so
30
+ # "C:\\path\\to\\images\\folder"
31
+ - folder_path: "/path/to/images/folder"
32
+ control_path: "/path/to/control/images/folder"
33
+ caption_ext: "txt"
34
+ # default_caption: "a person" # if caching text embeddings, if you don't have captions, this will get cached
35
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
36
+ resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ # caching text embeddings is required for 32GB
40
+ cache_text_embeddings: true
41
+
42
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ timestep_type: "weighted"
45
+ train_unet: true
46
+ train_text_encoder: false # probably won't work with qwen image
47
+ gradient_checkpointing: true # need the on unless you have a ton of vram
48
+ noise_scheduler: "flowmatch" # for training only
49
+ optimizer: "adamw8bit"
50
+ lr: 1e-4
51
+ # uncomment this to skip the pre training sample
52
+ # skip_first_sample: true
53
+ # uncomment to completely disable sampling
54
+ # disable_sampling: true
55
+ dtype: bf16
56
+ model:
57
+ # huggingface model name or path
58
+ name_or_path: "Qwen/Qwen-Image-Edit"
59
+ arch: "qwen_image_edit"
60
+ quantize: true
61
+ # qtype_te: "qfloat8" Default float8 qquantization
62
+ # to use the ARA use the | pipe to point to hf path, or a local path if you have one.
63
+ # 3bit is required for 32GB
64
+ qtype: "uint3|qwen_image_edit_torchao_uint3.safetensors"
65
+ quantize_te: true
66
+ qtype_te: "qfloat8"
67
+ low_vram: true
68
+ sample:
69
+ sampler: "flowmatch" # must match train.noise_scheduler
70
+ sample_every: 250 # sample every this many steps
71
+ width: 1024
72
+ height: 1024
73
+ samples:
74
+ - prompt: "do the thing to it"
75
+ ctrl_img: "/path/to/control/image.jpg"
76
+ - prompt: "do the thing to it"
77
+ ctrl_img: "/path/to/control/image.jpg"
78
+ - prompt: "do the thing to it"
79
+ ctrl_img: "/path/to/control/image.jpg"
80
+ - prompt: "do the thing to it"
81
+ ctrl_img: "/path/to/control/image.jpg"
82
+ - prompt: "do the thing to it"
83
+ ctrl_img: "/path/to/control/image.jpg"
84
+ - prompt: "do the thing to it"
85
+ ctrl_img: "/path/to/control/image.jpg"
86
+ - prompt: "do the thing to it"
87
+ ctrl_img: "/path/to/control/image.jpg"
88
+ - prompt: "do the thing to it"
89
+ ctrl_img: "/path/to/control/image.jpg"
90
+ - prompt: "do the thing to it"
91
+ ctrl_img: "/path/to/control/image.jpg"
92
+ - prompt: "do the thing to it"
93
+ ctrl_img: "/path/to/control/image.jpg"
94
+ neg: ""
95
+ seed: 42
96
+ walk_seed: true
97
+ guidance_scale: 3
98
+ sample_steps: 25
99
+ # you can add any additional meta info here. [name] is replaced with config name at top
100
+ meta:
101
+ name: "[name]"
102
+ version: '1.0'
ai-toolkit/config/examples/train_lora_sd35_large_24gb.yaml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_sd3l_lora_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: float16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
25
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
26
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
27
+ # hf_repo_id: your-username/your-model-slug
28
+ # hf_private: true #whether the repo is private or public
29
+ datasets:
30
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
31
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
32
+ # images will automatically be resized and bucketed into the resolution specified
33
+ # on windows, escape back slashes with another backslash so
34
+ # "C:\\path\\to\\images\\folder"
35
+ - folder_path: "/path/to/images/folder"
36
+ caption_ext: "txt"
37
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
38
+ shuffle_tokens: false # shuffle caption order, split by commas
39
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
40
+ resolution: [ 1024 ]
41
+ train:
42
+ batch_size: 1
43
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
44
+ gradient_accumulation_steps: 1
45
+ train_unet: true
46
+ train_text_encoder: false # May not fully work with SD3 yet
47
+ gradient_checkpointing: true # need the on unless you have a ton of vram
48
+ noise_scheduler: "flowmatch"
49
+ timestep_type: "linear" # linear or sigmoid
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
57
+ # linear_timesteps: true
58
+
59
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
60
+ ema_config:
61
+ use_ema: true
62
+ ema_decay: 0.99
63
+
64
+ # will probably need this if gpu supports it for sd3, other dtypes may not work correctly
65
+ dtype: bf16
66
+ model:
67
+ # huggingface model name or path
68
+ name_or_path: "stabilityai/stable-diffusion-3.5-large"
69
+ is_v3: true
70
+ quantize: true # run 8bit mixed precision
71
+ sample:
72
+ sampler: "flowmatch" # must match train.noise_scheduler
73
+ sample_every: 250 # sample every this many steps
74
+ width: 1024
75
+ height: 1024
76
+ prompts:
77
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
78
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
79
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
80
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
81
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
82
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
83
+ - "a bear building a log cabin in the snow covered mountains"
84
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
85
+ - "hipster man with a beard, building a chair, in a wood shop"
86
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
87
+ - "a man holding a sign that says, 'this is a sign'"
88
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
89
+ neg: ""
90
+ seed: 42
91
+ walk_seed: true
92
+ guidance_scale: 4
93
+ sample_steps: 25
94
+ # you can add any additional meta info here. [name] is replaced with config name at top
95
+ meta:
96
+ name: "[name]"
97
+ version: '1.0'
ai-toolkit/config/examples/train_lora_wan21_14b_24gb.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IMPORTANT: The Wan2.1 14B model is huge. This config should work on 24GB GPUs. It cannot
2
+ # support keeping the text encoder on GPU while training with 24GB, so it is only good
3
+ # for training on a single prompt, for example a person with a trigger word.
4
+ # to train on captions, you need more vran for now.
5
+ ---
6
+ job: extension
7
+ config:
8
+ # this name will be the folder and filename name
9
+ name: "my_first_wan21_14b_lora_v1"
10
+ process:
11
+ - type: 'sd_trainer'
12
+ # root folder to save training sessions/samples/weights
13
+ training_folder: "output"
14
+ # uncomment to see performance stats in the terminal every N steps
15
+ # performance_log_every: 1000
16
+ device: cuda:0
17
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
18
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
19
+ # this is probably needed for 24GB cards when offloading TE to CPU
20
+ trigger_word: "p3r5on"
21
+ network:
22
+ type: "lora"
23
+ linear: 32
24
+ linear_alpha: 32
25
+ save:
26
+ dtype: float16 # precision to save
27
+ save_every: 250 # save every this many steps
28
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
29
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
30
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
31
+ # hf_repo_id: your-username/your-model-slug
32
+ # hf_private: true #whether the repo is private or public
33
+ datasets:
34
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
35
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
36
+ # images will automatically be resized and bucketed into the resolution specified
37
+ # on windows, escape back slashes with another backslash so
38
+ # "C:\\path\\to\\images\\folder"
39
+ # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
40
+ # it works well for characters, but not as well for "actions"
41
+ - folder_path: "/path/to/images/folder"
42
+ caption_ext: "txt"
43
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
44
+ shuffle_tokens: false # shuffle caption order, split by commas
45
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
46
+ resolution: [ 632 ] # will be around 480p
47
+ train:
48
+ batch_size: 1
49
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
50
+ gradient_accumulation: 1
51
+ train_unet: true
52
+ train_text_encoder: false # probably won't work with wan
53
+ gradient_checkpointing: true # need the on unless you have a ton of vram
54
+ noise_scheduler: "flowmatch" # for training only
55
+ timestep_type: 'sigmoid'
56
+ optimizer: "adamw8bit"
57
+ lr: 1e-4
58
+ optimizer_params:
59
+ weight_decay: 1e-4
60
+ # uncomment this to skip the pre training sample
61
+ # skip_first_sample: true
62
+ # uncomment to completely disable sampling
63
+ # disable_sampling: true
64
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
65
+ ema_config:
66
+ use_ema: true
67
+ ema_decay: 0.99
68
+ dtype: bf16
69
+ # required for 24GB cards
70
+ # this will encode your trigger word and use those embeddings for every image in the dataset
71
+ unload_text_encoder: true
72
+ model:
73
+ # huggingface model name or path
74
+ name_or_path: "Wan-AI/Wan2.1-T2V-14B-Diffusers"
75
+ arch: 'wan21'
76
+ # these settings will save as much vram as possible
77
+ quantize: true
78
+ quantize_te: true
79
+ low_vram: true
80
+ sample:
81
+ sampler: "flowmatch"
82
+ sample_every: 250 # sample every this many steps
83
+ width: 832
84
+ height: 480
85
+ num_frames: 40
86
+ fps: 15
87
+ # samples take a long time. so use them sparingly
88
+ # samples will be animated webp files, if you don't see them animated, open in a browser.
89
+ prompts:
90
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
91
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
92
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
93
+ neg: ""
94
+ seed: 42
95
+ walk_seed: true
96
+ guidance_scale: 5
97
+ sample_steps: 30
98
+ # you can add any additional meta info here. [name] is replaced with config name at top
99
+ meta:
100
+ name: "[name]"
101
+ version: '1.0'
ai-toolkit/config/examples/train_lora_wan21_1b_24gb.yaml ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_wan21_1b_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 32
19
+ linear_alpha: 32
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
35
+ # it works well for characters, but not as well for "actions"
36
+ - folder_path: "/path/to/images/folder"
37
+ caption_ext: "txt"
38
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
39
+ shuffle_tokens: false # shuffle caption order, split by commas
40
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
41
+ resolution: [ 632 ] # will be around 480p
42
+ train:
43
+ batch_size: 1
44
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
45
+ gradient_accumulation: 1
46
+ train_unet: true
47
+ train_text_encoder: false # probably won't work with wan
48
+ gradient_checkpointing: true # need the on unless you have a ton of vram
49
+ noise_scheduler: "flowmatch" # for training only
50
+ timestep_type: 'sigmoid'
51
+ optimizer: "adamw8bit"
52
+ lr: 1e-4
53
+ optimizer_params:
54
+ weight_decay: 1e-4
55
+ # uncomment this to skip the pre training sample
56
+ # skip_first_sample: true
57
+ # uncomment to completely disable sampling
58
+ # disable_sampling: true
59
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
60
+ ema_config:
61
+ use_ema: true
62
+ ema_decay: 0.99
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
67
+ arch: 'wan21'
68
+ quantize_te: true # saves vram
69
+ sample:
70
+ sampler: "flowmatch"
71
+ sample_every: 250 # sample every this many steps
72
+ width: 832
73
+ height: 480
74
+ num_frames: 40
75
+ fps: 15
76
+ # samples take a long time. so use them sparingly
77
+ # samples will be animated webp files, if you don't see them animated, open in a browser.
78
+ prompts:
79
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
80
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
81
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
82
+ neg: ""
83
+ seed: 42
84
+ walk_seed: true
85
+ guidance_scale: 5
86
+ sample_steps: 30
87
+ # you can add any additional meta info here. [name] is replaced with config name at top
88
+ meta:
89
+ name: "[name]"
90
+ version: '1.0'
ai-toolkit/config/examples/train_lora_wan22_14b_24gb.yaml ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this example focuses mainly for training Wan2.2 14b on images. It will work for video as well by increasing
2
+ # the number of frames in the dataset and samples. Training on and generating video is very VRAM intensive.
3
+ ---
4
+ job: extension
5
+ config:
6
+ # this name will be the folder and filename name
7
+ name: "my_first_wan22_14b_lora_v1"
8
+ process:
9
+ - type: 'sd_trainer'
10
+ # root folder to save training sessions/samples/weights
11
+ training_folder: "output"
12
+ # uncomment to see performance stats in the terminal every N steps
13
+ # performance_log_every: 1000
14
+ device: cuda:0
15
+ # Use a trigger word if train.unload_text_encoder is true, however, if caching text embeddings, do not use a trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 32
20
+ linear_alpha: 32
21
+ save:
22
+ dtype: float16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
25
+ datasets:
26
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
27
+ # for instance image2.jpg and image2.txt.
28
+ # "C:\\path\\to\\images\\folder"
29
+ - folder_path: "/path/to/images/or/video/folder"
30
+ caption_ext: "txt"
31
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
32
+ # number of frames to extract from your video. It will automatically extract them evenly spaced
33
+ # set to 1 frame for images
34
+ num_frames: 1
35
+ resolution: [ 512, 768, 1024]
36
+ train:
37
+ batch_size: 1
38
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
39
+ gradient_accumulation: 1
40
+ train_unet: true
41
+ train_text_encoder: false # probably won't work with wan
42
+ gradient_checkpointing: true # need the on unless you have a ton of vram
43
+ noise_scheduler: "flowmatch" # for training only
44
+ timestep_type: 'linear'
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ optimizer_params:
48
+ weight_decay: 1e-4
49
+ # uncomment this to skip the pre training sample
50
+ # skip_first_sample: true
51
+ # uncomment to completely disable sampling
52
+ # disable_sampling: true
53
+ dtype: bf16
54
+
55
+ # IMPORTANT: this is for Wan 2.2 MOE. It will switch training one stage or the other every this many steps
56
+ switch_boundary_every: 10
57
+
58
+ # required for 24GB cards. You must do either unload_text_encoder or cache_text_embeddings but not both
59
+
60
+ # this will encode your trigger word and use those embeddings for every image in the dataset, captions will be ignored
61
+ # unload_text_encoder: true
62
+
63
+ # this will cache all captions in your dataset.
64
+ cache_text_embeddings: true
65
+
66
+ model:
67
+ # huggingface model name or path, this one if bf16, vs the float32 of the official repo
68
+ name_or_path: "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16"
69
+ arch: 'wan22_14b'
70
+ quantize: true
71
+ # This will pull and use a custom Accuracy Recovery Adapter to train at 4bit
72
+ qtype: "uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors"
73
+ quantize_te: true
74
+ qtype_te: "qfloat8"
75
+ low_vram: true
76
+ model_kwargs:
77
+ # you can train high noise, low noise, or both. With low vram it will automatically unload the one not being trained.
78
+ train_high_noise: true
79
+ train_low_noise: true
80
+ sample:
81
+ sampler: "flowmatch"
82
+ sample_every: 250 # sample every this many steps
83
+ width: 1024
84
+ height: 1024
85
+ # set to 1 for images
86
+ num_frames: 1
87
+ fps: 16
88
+ # samples take a long time. so use them sparingly
89
+ # samples will be animated webp files, if you don't see them animated, open in a browser.
90
+ prompts:
91
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
92
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
93
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
94
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
95
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
96
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
97
+ - "a bear building a log cabin in the snow covered mountains"
98
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
99
+ - "hipster man with a beard, building a chair, in a wood shop"
100
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
101
+ - "a man holding a sign that says, 'this is a sign'"
102
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
103
+ neg: ""
104
+ seed: 42
105
+ walk_seed: true
106
+ guidance_scale: 3.5
107
+ sample_steps: 25
108
+ # you can add any additional meta info here. [name] is replaced with config name at top
109
+ meta:
110
+ name: "[name]"
111
+ version: '1.0'
ai-toolkit/config/examples/train_slider.example.yml ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # This is in yaml format. You can use json if you prefer
3
+ # I like both but yaml is easier to write
4
+ # Plus it has comments which is nice for documentation
5
+ # This is the config I use on my sliders, It is solid and tested
6
+ job: train
7
+ config:
8
+ # the name will be used to create a folder in the output folder
9
+ # it will also replace any [name] token in the rest of this config
10
+ name: detail_slider_v1
11
+ # folder will be created with name above in folder below
12
+ # it can be relative to the project root or absolute
13
+ training_folder: "output/LoRA"
14
+ device: cuda:0 # cpu, cuda:0, etc
15
+ # for tensorboard logging, we will make a subfolder for this job
16
+ log_dir: "output/.tensorboard"
17
+ # you can stack processes for other jobs, It is not tested with sliders though
18
+ # just use one for now
19
+ process:
20
+ - type: slider # tells runner to run the slider process
21
+ # network is the LoRA network for a slider, I recommend to leave this be
22
+ network:
23
+ # network type lierla is traditional LoRA that works everywhere, only linear layers
24
+ type: "lierla"
25
+ # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
26
+ linear: 8
27
+ linear_alpha: 4 # Do about half of rank
28
+ # training config
29
+ train:
30
+ # this is also used in sampling. Stick with ddpm unless you know what you are doing
31
+ noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
32
+ # how many steps to train. More is not always better. I rarely go over 1000
33
+ steps: 500
34
+ # I have had good results with 4e-4 to 1e-4 at 500 steps
35
+ lr: 2e-4
36
+ # enables gradient checkpoint, saves vram, leave it on
37
+ gradient_checkpointing: true
38
+ # train the unet. I recommend leaving this true
39
+ train_unet: true
40
+ # train the text encoder. I don't recommend this unless you have a special use case
41
+ # for sliders we are adjusting representation of the concept (unet),
42
+ # not the description of it (text encoder)
43
+ train_text_encoder: false
44
+ # same as from sd-scripts, not fully tested but should speed up training
45
+ min_snr_gamma: 5.0
46
+ # just leave unless you know what you are doing
47
+ # also supports "dadaptation" but set lr to 1 if you use that,
48
+ # but it learns too fast and I don't recommend it
49
+ optimizer: "adamw"
50
+ # only constant for now
51
+ lr_scheduler: "constant"
52
+ # we randomly denoise random num of steps form 1 to this number
53
+ # while training. Just leave it
54
+ max_denoising_steps: 40
55
+ # works great at 1. I do 1 even with my 4090.
56
+ # higher may not work right with newer single batch stacking code anyway
57
+ batch_size: 1
58
+ # bf16 works best if your GPU supports it (modern)
59
+ dtype: bf16 # fp32, bf16, fp16
60
+ # if you have it, use it. It is faster and better
61
+ # torch 2.0 doesnt need xformers anymore, only use if you have lower version
62
+ # xformers: true
63
+ # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
64
+ # although, the way we train sliders is comparative, so it probably won't work anyway
65
+ noise_offset: 0.0
66
+ # noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL
67
+
68
+ # the model to train the LoRA network on
69
+ model:
70
+ # huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt
71
+ name_or_path: "runwayml/stable-diffusion-v1-5"
72
+ is_v2: false # for v2 models
73
+ is_v_pred: false # for v-prediction models (most v2 models)
74
+ # has some issues with the dual text encoder and the way we train sliders
75
+ # it works bit weights need to probably be higher to see it.
76
+ is_xl: false # for SDXL models
77
+
78
+ # saving config
79
+ save:
80
+ dtype: float16 # precision to save. I recommend float16
81
+ save_every: 50 # save every this many steps
82
+ # this will remove step counts more than this number
83
+ # allows you to save more often in case of a crash without filling up your drive
84
+ max_step_saves_to_keep: 2
85
+
86
+ # sampling config
87
+ sample:
88
+ # must match train.noise_scheduler, this is not used here
89
+ # but may be in future and in other processes
90
+ sampler: "ddpm"
91
+ # sample every this many steps
92
+ sample_every: 20
93
+ # image size
94
+ width: 512
95
+ height: 512
96
+ # prompts to use for sampling. Do as many as you want, but it slows down training
97
+ # pick ones that will best represent the concept you are trying to adjust
98
+ # allows some flags after the prompt
99
+ # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
100
+ # slide are good tests. will inherit sample.network_multiplier if not set
101
+ # --n [string] # negative prompt, will inherit sample.neg if not set
102
+ # Only 75 tokens allowed currently
103
+ # I like to do a wide positive and negative spread so I can see a good range and stop
104
+ # early if the network is braking down
105
+ prompts:
106
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
107
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
108
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
109
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
110
+ - "a golden retriever sitting on a leather couch, --m -5"
111
+ - "a golden retriever sitting on a leather couch --m -3"
112
+ - "a golden retriever sitting on a leather couch --m 3"
113
+ - "a golden retriever sitting on a leather couch --m 5"
114
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
115
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
116
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
117
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
118
+ # negative prompt used on all prompts above as default if they don't have one
119
+ neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
120
+ # seed for sampling. 42 is the answer for everything
121
+ seed: 42
122
+ # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc
123
+ # will start over on next sample_every so s1 is always seed
124
+ # works well if you use same prompt but want different results
125
+ walk_seed: false
126
+ # cfg scale (4 to 10 is good)
127
+ guidance_scale: 7
128
+ # sampler steps (20 to 30 is good)
129
+ sample_steps: 20
130
+ # default network multiplier for all prompts
131
+ # since we are training a slider, I recommend overriding this with --m [number]
132
+ # in the prompts above to get both sides of the slider
133
+ network_multiplier: 1.0
134
+
135
+ # logging information
136
+ logging:
137
+ log_every: 10 # log every this many steps
138
+ use_wandb: false # not supported yet
139
+ verbose: false # probably done need unless you are debugging
140
+
141
+ # slider training config, best for last
142
+ slider:
143
+ # resolutions to train on. [ width, height ]. This is less important for sliders
144
+ # as we are not teaching the model anything it doesn't already know
145
+ # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
146
+ # and [ 1024, 1024 ] for sd_xl
147
+ # you can do as many as you want here
148
+ resolutions:
149
+ - [ 512, 512 ]
150
+ # - [ 512, 768 ]
151
+ # - [ 768, 768 ]
152
+ # slider training uses 4 combined steps for a single round. This will do it in one gradient
153
+ # step. It is highly optimized and shouldn't take anymore vram than doing without it,
154
+ # since we break down batches for gradient accumulation now. so just leave it on.
155
+ batch_full_slide: true
156
+ # These are the concepts to train on. You can do as many as you want here,
157
+ # but they can conflict outweigh each other. Other than experimenting, I recommend
158
+ # just doing one for good results
159
+ targets:
160
+ # target_class is the base concept we are adjusting the representation of
161
+ # for example, if we are adjusting the representation of a person, we would use "person"
162
+ # if we are adjusting the representation of a cat, we would use "cat" It is not
163
+ # a keyword necessarily but what the model understands the concept to represent.
164
+ # "person" will affect men, women, children, etc but will not affect cats, dogs, etc
165
+ # it is the models base general understanding of the concept and everything it represents
166
+ # you can leave it blank to affect everything. In this example, we are adjusting
167
+ # detail, so we will leave it blank to affect everything
168
+ - target_class: ""
169
+ # positive is the prompt for the positive side of the slider.
170
+ # It is the concept that will be excited and amplified in the model when we slide the slider
171
+ # to the positive side and forgotten / inverted when we slide
172
+ # the slider to the negative side. It is generally best to include the target_class in
173
+ # the prompt. You want it to be the extreme of what you want to train on. For example,
174
+ # if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
175
+ # as the prompt. Not just "fat person"
176
+ # max 75 tokens for now
177
+ positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
178
+ # negative is the prompt for the negative side of the slider and works the same as positive
179
+ # it does not necessarily work the same as a negative prompt when generating images
180
+ # these need to be polar opposites.
181
+ # max 76 tokens for now
182
+ negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
183
+ # the loss for this target is multiplied by this number.
184
+ # if you are doing more than one target it may be good to set less important ones
185
+ # to a lower number like 0.1 so they don't outweigh the primary target
186
+ weight: 1.0
187
+ # shuffle the prompts split by the comma. We will run every combination randomly
188
+ # this will make the LoRA more robust. You probably want this on unless prompt order
189
+ # is important for some reason
190
+ shuffle: true
191
+
192
+
193
+ # anchors are prompts that we will try to hold on to while training the slider
194
+ # these are NOT necessary and can prevent the slider from converging if not done right
195
+ # leave them off if you are having issues, but they can help lock the network
196
+ # on certain concepts to help prevent catastrophic forgetting
197
+ # you want these to generate an image that is not your target_class, but close to it
198
+ # is fine as long as it does not directly overlap it.
199
+ # For example, if you are training on a person smiling,
200
+ # you could use "a person with a face mask" as an anchor. It is a person, the image is the same
201
+ # regardless if they are smiling or not, however, the closer the concept is to the target_class
202
+ # the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
203
+ # for close concepts, you want to be closer to 0.1 or 0.2
204
+ # these will slow down training. I am leaving them off for the demo
205
+
206
+ # anchors:
207
+ # - prompt: "a woman"
208
+ # neg_prompt: "animal"
209
+ # # the multiplier applied to the LoRA when this is run.
210
+ # # higher will give it more weight but also help keep the lora from collapsing
211
+ # multiplier: 1.0
212
+ # - prompt: "a man"
213
+ # neg_prompt: "animal"
214
+ # multiplier: 1.0
215
+ # - prompt: "a person"
216
+ # neg_prompt: "animal"
217
+ # multiplier: 1.0
218
+
219
+ # You can put any information you want here, and it will be saved in the model.
220
+ # The below is an example, but you can put your grocery list in it if you want.
221
+ # It is saved in the model so be aware of that. The software will include this
222
+ # plus some other information for you automatically
223
+ meta:
224
+ # [name] gets replaced with the name above
225
+ name: "[name]"
226
+ # version: '1.0'
227
+ # creator:
228
+ # name: Your Name
229
+ # email: your@gmail.com
230
+ # website: https://your.website
ai-toolkit/dgx_instructions.md ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AI Toolkit by Ostris
2
+
3
+ ## DGX OS installation instructions
4
+
5
+ You need to use Python 3.11 to run AI Toolkit on DGX OS. The easiest way to do this without affecting the system installation of Python is to create a virtual environment with **miniconda**, which allows you to specify the version of Python to use in the environment.
6
+
7
+ This guide will assume you have a fresh installation of DGX OS, and will guide you through the installation of all requirements.
8
+
9
+ ### Installation instructions for DGX OS:
10
+
11
+ **1) Get Python 3.11 (via miniconda)**
12
+
13
+ Install the latest version of miniconda:
14
+ ```
15
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh
16
+ chmod u+x Miniconda3-latest-Linux-aarch64.sh
17
+ ./Miniconda3-latest-Linux-aarch64.sh
18
+ ```
19
+
20
+ Restart your bash or ssh session. If miniconda was installed successfully, it will automatically load the 'base' environment by default. If you want to disable this behaviour, run:
21
+ ```
22
+ conda config --set auto_activate_base false
23
+ ```
24
+
25
+ Now you can create a Python 3.11 environment for ai-toolkit:
26
+ ```
27
+ conda create --name ai-toolkit python=3.11
28
+ ```
29
+
30
+ Then activate the environment with:
31
+
32
+ ```
33
+ conda activate ai-toolkit
34
+ ```
35
+
36
+
37
+ **2) Install PyTorch**
38
+
39
+ ```
40
+ pip3 install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu130
41
+ ```
42
+
43
+
44
+ **3) Install the remaining requirements (dgx_requirements.txt)**
45
+
46
+ ```
47
+ pip3 install -r dgx_requirements.txt
48
+ ```
49
+
50
+ ### Running the UI on DGX OS:
51
+
52
+ Running the UI is not that different from doing it on other systems, however, you need to install the ARM64 version of NodeJS for Linux, which is compatible with the NVIDIA Grace CPU.
53
+
54
+
55
+ **1) Install Node.js**
56
+
57
+ Download a Linux ARM64 build of Node.js from: https://nodejs.org (for example: https://nodejs.org/dist/v24.11.1/node-v24.11.1-linux-arm64.tar.xz)
58
+
59
+ Extract it and add the bin directory to your path. I extracted it to **/opt** and added the following to my ~/.bashrc file:
60
+ ```
61
+ export PATH=“/opt/node-v24.11.1-linux-arm64/bin:$PATH”
62
+ ```
63
+
64
+
65
+ **2) Compile and run the Node.js UI**
66
+
67
+ Change to the ui directory, then build and run the UI:
68
+ ```
69
+ cd ui
70
+ npm run build_and_start
71
+ ```
72
+
73
+ If all went well, you’ll be able to access the UI on port 8675 and start training.
74
+
75
+
76
+ <details>
77
+ <summary>Troubleshooting issues</summary>
78
+ If you’re not getting any output when starting a training job from the UI, it’s probably crashing before the process started, the best way to debug these issues is to run the python training script directly (which is normally started by the UI). To do this, set up a training job in the UI, go to the advanced config screen, copy and paste the configuration into a file like train.yaml, then run the training script like this with the conda virtual environment active:
79
+
80
+ ```
81
+ python run.py path/to/train.yaml
82
+ ```
83
+ </details>
84
+ <br>
ai-toolkit/dgx_requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # You need to use Python 3.11, the easiest way to get this on DGX OS without impacting the system version of Python is to create an environment with miniconda.
2
+
3
+ # specific dependency versions needed on DGX OS devices:
4
+ scipy==1.16.0
5
+ tifffile==2025.6.11
6
+ imageio==2.37.0
7
+ scikit_image==0.25.2
8
+ clean_fid==0.1.35
9
+ pywavelets==1.9.0
10
+ contourpy==1.3.3
11
+ opencv_python_headless==4.11.0.86
12
+
13
+ -r requirements_base.txt
ai-toolkit/docker-compose.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "3.8"
2
+
3
+ services:
4
+ ai-toolkit:
5
+ image: ostris/aitoolkit:latest
6
+ restart: unless-stopped
7
+ ports:
8
+ - "8675:8675"
9
+ volumes:
10
+ - ~/.cache/huggingface/hub:/root/.cache/huggingface/hub
11
+ - ./aitk_db.db:/app/ai-toolkit/aitk_db.db
12
+ - ./datasets:/app/ai-toolkit/datasets
13
+ - ./output:/app/ai-toolkit/output
14
+ - ./config:/app/ai-toolkit/config
15
+ environment:
16
+ - AI_TOOLKIT_AUTH=${AI_TOOLKIT_AUTH:-password}
17
+ - NODE_ENV=production
18
+ - TZ=UTC
19
+ deploy:
20
+ resources:
21
+ reservations:
22
+ devices:
23
+ - driver: nvidia
24
+ count: all
25
+ capabilities: [gpu]
ai-toolkit/docker/Dockerfile ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.8.1-devel-ubuntu24.04
2
+
3
+ LABEL authors="jaret"
4
+
5
+ # Set noninteractive to avoid timezone prompts
6
+ ENV DEBIAN_FRONTEND=noninteractive
7
+
8
+ # ref https://en.wikipedia.org/wiki/CUDA
9
+ ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0 10.0 12.0"
10
+
11
+ # Install dependencies
12
+ RUN apt-get update && apt-get install --no-install-recommends -y \
13
+ git \
14
+ curl \
15
+ build-essential \
16
+ cmake \
17
+ wget \
18
+ python3.12 \
19
+ python3-pip \
20
+ python3-dev \
21
+ python3-setuptools \
22
+ python3-wheel \
23
+ python3-venv \
24
+ ffmpeg \
25
+ tmux \
26
+ htop \
27
+ nvtop \
28
+ python3-opencv \
29
+ openssh-client \
30
+ openssh-server \
31
+ openssl \
32
+ rsync \
33
+ unzip \
34
+ && apt-get clean \
35
+ && rm -rf /var/lib/apt/lists/*
36
+
37
+ # Install nodejs
38
+ WORKDIR /tmp
39
+ RUN curl -sL https://deb.nodesource.com/setup_23.x -o nodesource_setup.sh && \
40
+ bash nodesource_setup.sh && \
41
+ apt-get update && \
42
+ apt-get install -y nodejs && \
43
+ apt-get clean && \
44
+ rm -rf /var/lib/apt/lists/*
45
+
46
+ WORKDIR /app
47
+
48
+ # Set aliases for python and pip
49
+ RUN ln -s /usr/bin/python3 /usr/bin/python
50
+
51
+ # install pytorch before cache bust to avoid redownloading pytorch
52
+ RUN pip install --no-cache-dir torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128 --break-system-packages
53
+
54
+ WORKDIR /app/ai-toolkit
55
+
56
+ # ---------------------------------------------------------------------------- #
57
+ # Dependency layers come BEFORE the source clone so they are only rebuilt (and
58
+ # only need to be re-pulled by servers) when the dependency manifests change,
59
+ # not on every code change.
60
+ # ---------------------------------------------------------------------------- #
61
+
62
+ # Install Python dependencies (only re-runs when the requirements files change)
63
+ COPY requirements.txt requirements_base.txt /app/ai-toolkit/
64
+ RUN pip install --no-cache-dir --break-system-packages -r requirements.txt && \
65
+ pip install setuptools==69.5.1 --no-cache-dir --break-system-packages
66
+
67
+ # Install Node dependencies (only re-runs when package.json / package-lock.json change)
68
+ COPY ui/package.json ui/package-lock.json /app/ai-toolkit/ui/
69
+ RUN cd /app/ai-toolkit/ui && npm ci
70
+
71
+ # ---------------------------------------------------------------------------- #
72
+ # Source code comes LAST. Only this layer (plus the UI build below) is rebuilt
73
+ # on a code change, so servers only re-pull the (small) source, not the deps.
74
+ # Clone to a temp dir and rsync the source in, preserving the dependency dirs
75
+ # already populated above (ui/node_modules) and the manifests already used.
76
+ # ---------------------------------------------------------------------------- #
77
+ ARG CACHEBUST=1234
78
+ ARG GIT_COMMIT=main
79
+ RUN echo "Cache bust: ${CACHEBUST}" && \
80
+ git clone https://github.com/ostris/ai-toolkit.git /tmp/ai-toolkit-src && \
81
+ cd /tmp/ai-toolkit-src && \
82
+ git checkout ${GIT_COMMIT} && \
83
+ rsync -a --delete \
84
+ --exclude 'ui/node_modules' \
85
+ --exclude 'requirements.txt' \
86
+ --exclude 'ui/package.json' \
87
+ --exclude 'ui/package-lock.json' \
88
+ /tmp/ai-toolkit-src/ /app/ai-toolkit/ && \
89
+ rm -rf /tmp/ai-toolkit-src
90
+
91
+ # Build UI (re-runs on code change, but reuses the cached node_modules above).
92
+ # update_db runs first because it does `prisma generate`, which creates the
93
+ # @prisma/client types the TS build needs. In the old layout generate happened
94
+ # as a side effect of npm install seeing the schema; now the source arrives
95
+ # after npm ci, so run it explicitly before the build.
96
+ RUN cd /app/ai-toolkit/ui && \
97
+ npm run update_db && \
98
+ npm run build
99
+
100
+ # Expose port (assuming the application runs on port 3000)
101
+ EXPOSE 8675
102
+
103
+ WORKDIR /
104
+
105
+ COPY docker/start.sh /start.sh
106
+ RUN chmod +x /start.sh
107
+
108
+ CMD ["/start.sh"]
ai-toolkit/docker/start.sh ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e # Exit the script if any statement returns a non-true return value
3
+
4
+ # ref https://github.com/runpod/containers/blob/main/container-template/start.sh
5
+
6
+ # ---------------------------------------------------------------------------- #
7
+ # Function Definitions #
8
+ # ---------------------------------------------------------------------------- #
9
+
10
+
11
+ # Setup ssh
12
+ setup_ssh() {
13
+ if [[ $PUBLIC_KEY ]]; then
14
+ echo "Setting up SSH..."
15
+ mkdir -p ~/.ssh
16
+ echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys
17
+ chmod 700 -R ~/.ssh
18
+
19
+ if [ ! -f /etc/ssh/ssh_host_rsa_key ]; then
20
+ ssh-keygen -t rsa -f /etc/ssh/ssh_host_rsa_key -q -N ''
21
+ echo "RSA key fingerprint:"
22
+ ssh-keygen -lf /etc/ssh/ssh_host_rsa_key.pub
23
+ fi
24
+
25
+ if [ ! -f /etc/ssh/ssh_host_dsa_key ]; then
26
+ ssh-keygen -t dsa -f /etc/ssh/ssh_host_dsa_key -q -N ''
27
+ echo "DSA key fingerprint:"
28
+ ssh-keygen -lf /etc/ssh/ssh_host_dsa_key.pub
29
+ fi
30
+
31
+ if [ ! -f /etc/ssh/ssh_host_ecdsa_key ]; then
32
+ ssh-keygen -t ecdsa -f /etc/ssh/ssh_host_ecdsa_key -q -N ''
33
+ echo "ECDSA key fingerprint:"
34
+ ssh-keygen -lf /etc/ssh/ssh_host_ecdsa_key.pub
35
+ fi
36
+
37
+ if [ ! -f /etc/ssh/ssh_host_ed25519_key ]; then
38
+ ssh-keygen -t ed25519 -f /etc/ssh/ssh_host_ed25519_key -q -N ''
39
+ echo "ED25519 key fingerprint:"
40
+ ssh-keygen -lf /etc/ssh/ssh_host_ed25519_key.pub
41
+ fi
42
+
43
+ service ssh start
44
+
45
+ echo "SSH host keys:"
46
+ for key in /etc/ssh/*.pub; do
47
+ echo "Key: $key"
48
+ ssh-keygen -lf $key
49
+ done
50
+ fi
51
+ }
52
+
53
+ # Export env vars
54
+ export_env_vars() {
55
+ echo "Exporting environment variables..."
56
+ printenv | grep -E '^RUNPOD_|^PATH=|^_=' | awk -F = '{ print "export " $1 "=\"" $2 "\"" }' >> /etc/rp_environment
57
+ echo 'source /etc/rp_environment' >> ~/.bashrc
58
+ }
59
+
60
+ # ---------------------------------------------------------------------------- #
61
+ # Main Program #
62
+ # ---------------------------------------------------------------------------- #
63
+
64
+
65
+ echo "Pod Started"
66
+
67
+ setup_ssh
68
+ export_env_vars
69
+ echo "Starting AI Toolkit UI..."
70
+ cd /app/ai-toolkit/ui && npm run start
ai-toolkit/extensions/example/ExampleMergeModels.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gc
3
+ from collections import OrderedDict
4
+ from typing import TYPE_CHECKING
5
+ from jobs.process import BaseExtensionProcess
6
+ from toolkit.config_modules import ModelConfig
7
+ from toolkit.stable_diffusion_model import StableDiffusion
8
+ from toolkit.train_tools import get_torch_dtype
9
+ from tqdm import tqdm
10
+
11
+ # Type check imports. Prevents circular imports
12
+ if TYPE_CHECKING:
13
+ from jobs import ExtensionJob
14
+
15
+
16
+ # extend standard config classes to add weight
17
+ class ModelInputConfig(ModelConfig):
18
+ def __init__(self, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.weight = kwargs.get('weight', 1.0)
21
+ # overwrite default dtype unless user specifies otherwise
22
+ # float 32 will give up better precision on the merging functions
23
+ self.dtype: str = kwargs.get('dtype', 'float32')
24
+
25
+
26
+ def flush():
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+
31
+ # this is our main class process
32
+ class ExampleMergeModels(BaseExtensionProcess):
33
+ def __init__(
34
+ self,
35
+ process_id: int,
36
+ job: 'ExtensionJob',
37
+ config: OrderedDict
38
+ ):
39
+ super().__init__(process_id, job, config)
40
+ # this is the setup process, do not do process intensive stuff here, just variable setup and
41
+ # checking requirements. This is called before the run() function
42
+ # no loading models or anything like that, it is just for setting up the process
43
+ # all of your process intensive stuff should be done in the run() function
44
+ # config will have everything from the process item in the config file
45
+
46
+ # convince methods exist on BaseProcess to get config values
47
+ # if required is set to true and the value is not found it will throw an error
48
+ # you can pass a default value to get_conf() as well if it was not in the config file
49
+ # as well as a type to cast the value to
50
+ self.save_path = self.get_conf('save_path', required=True)
51
+ self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
52
+ self.device = self.get_conf('device', default='cpu', as_type=torch.device)
53
+
54
+ # build models to merge list
55
+ models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
56
+ # build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
57
+ # this way you can add methods to it and it is easier to read and code. There are a lot of
58
+ # inbuilt config classes located in toolkit.config_modules as well
59
+ self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
60
+ # setup is complete. Don't load anything else here, just setup variables and stuff
61
+
62
+ # this is the entire run process be sure to call super().run() first
63
+ def run(self):
64
+ # always call first
65
+ super().run()
66
+ print(f"Running process: {self.__class__.__name__}")
67
+
68
+ # let's adjust our weights first to normalize them so the total is 1.0
69
+ total_weight = sum([model.weight for model in self.models_to_merge])
70
+ weight_adjust = 1.0 / total_weight
71
+ for model in self.models_to_merge:
72
+ model.weight *= weight_adjust
73
+
74
+ output_model: StableDiffusion = None
75
+ # let's do the merge, it is a good idea to use tqdm to show progress
76
+ for model_config in tqdm(self.models_to_merge, desc="Merging models"):
77
+ # setup model class with our helper class
78
+ sd_model = StableDiffusion(
79
+ device=self.device,
80
+ model_config=model_config,
81
+ dtype="float32"
82
+ )
83
+ # load the model
84
+ sd_model.load_model()
85
+
86
+ # adjust the weight of the text encoder
87
+ if isinstance(sd_model.text_encoder, list):
88
+ # sdxl model
89
+ for text_encoder in sd_model.text_encoder:
90
+ for key, value in text_encoder.state_dict().items():
91
+ value *= model_config.weight
92
+ else:
93
+ # normal model
94
+ for key, value in sd_model.text_encoder.state_dict().items():
95
+ value *= model_config.weight
96
+ # adjust the weights of the unet
97
+ for key, value in sd_model.unet.state_dict().items():
98
+ value *= model_config.weight
99
+
100
+ if output_model is None:
101
+ # use this one as the base
102
+ output_model = sd_model
103
+ else:
104
+ # merge the models
105
+ # text encoder
106
+ if isinstance(output_model.text_encoder, list):
107
+ # sdxl model
108
+ for i, text_encoder in enumerate(output_model.text_encoder):
109
+ for key, value in text_encoder.state_dict().items():
110
+ value += sd_model.text_encoder[i].state_dict()[key]
111
+ else:
112
+ # normal model
113
+ for key, value in output_model.text_encoder.state_dict().items():
114
+ value += sd_model.text_encoder.state_dict()[key]
115
+ # unet
116
+ for key, value in output_model.unet.state_dict().items():
117
+ value += sd_model.unet.state_dict()[key]
118
+
119
+ # remove the model to free memory
120
+ del sd_model
121
+ flush()
122
+
123
+ # merge loop is done, let's save the model
124
+ print(f"Saving merged model to {self.save_path}")
125
+ output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
126
+ print(f"Saved merged model to {self.save_path}")
127
+ # do cleanup here
128
+ del output_model
129
+ flush()
ai-toolkit/extensions/example/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # We make a subclass of Extension
6
+ class ExampleMergeExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "example_merge_extension"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Example Merge Extension"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ExampleMergeModels import ExampleMergeModels
19
+ return ExampleMergeModels
20
+
21
+
22
+ AI_TOOLKIT_EXTENSIONS = [
23
+ # you can put a list of extensions here
24
+ ExampleMergeExtension
25
+ ]
ai-toolkit/extensions/example/config/config.example.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # Always include at least one example config file to show how to use your extension.
3
+ # use plenty of comments so users know how to use it and what everything does
4
+
5
+ # all extensions will use this job name
6
+ job: extension
7
+ config:
8
+ name: 'my_awesome_merge'
9
+ process:
10
+ # Put your example processes here. This will be passed
11
+ # to your extension process in the config argument.
12
+ # the type MUST match your extension uid
13
+ - type: "example_merge_extension"
14
+ # save path for the merged model
15
+ save_path: "output/merge/[name].safetensors"
16
+ # save type
17
+ dtype: fp16
18
+ # device to run it on
19
+ device: cuda:0
20
+ # input models can only be SD1.x and SD2.x models for this example (currently)
21
+ models_to_merge:
22
+ # weights are relative, total weights will be normalized
23
+ # for example. If you have 2 models with weight 1.0, they will
24
+ # both be weighted 0.5. If you have 1 model with weight 1.0 and
25
+ # another with weight 2.0, the first will be weighted 1/3 and the
26
+ # second will be weighted 2/3
27
+ - name_or_path: "input/model1.safetensors"
28
+ weight: 1.0
29
+ - name_or_path: "input/model2.safetensors"
30
+ weight: 1.0
31
+ - name_or_path: "input/model3.safetensors"
32
+ weight: 0.3
33
+ - name_or_path: "input/model4.safetensors"
34
+ weight: 1.0
35
+
36
+
37
+ # you can put any information you want here, and it will be saved in the model
38
+ # the below is an example. I recommend doing trigger words at a minimum
39
+ # in the metadata. The software will include this plus some other information
40
+ meta:
41
+ name: "[name]" # [name] gets replaced with the name above
42
+ description: A short description of your model
43
+ version: '0.1'
44
+ creator:
45
+ name: Your Name
46
+ email: your@email.com
47
+ website: https://yourwebsite.com
48
+ any: All meta data above is arbitrary, it can be whatever you want.
ai-toolkit/extensions_built_in/advanced_generator/Img2ImgGenerator.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ from collections import OrderedDict
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from diffusers import T2IAdapter
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from torch.utils.data import DataLoader
12
+ from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
13
+ from tqdm import tqdm
14
+
15
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
16
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
17
+ from toolkit.sampler import get_sampler
18
+ from toolkit.stable_diffusion_model import StableDiffusion
19
+ import gc
20
+ import torch
21
+ from jobs.process import BaseExtensionProcess
22
+ from toolkit.data_loader import get_dataloader_from_datasets
23
+ from toolkit.train_tools import get_torch_dtype
24
+ from controlnet_aux.midas import MidasDetector
25
+ from diffusers.utils import load_image
26
+ from torchvision.transforms import ToTensor
27
+
28
+
29
+ def flush():
30
+ torch.cuda.empty_cache()
31
+ gc.collect()
32
+
33
+
34
+
35
+
36
+
37
+ class GenerateConfig:
38
+
39
+ def __init__(self, **kwargs):
40
+ self.prompts: List[str]
41
+ self.sampler = kwargs.get('sampler', 'ddpm')
42
+ self.neg = kwargs.get('neg', '')
43
+ self.seed = kwargs.get('seed', -1)
44
+ self.walk_seed = kwargs.get('walk_seed', False)
45
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
46
+ self.sample_steps = kwargs.get('sample_steps', 20)
47
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
48
+ self.ext = kwargs.get('ext', 'png')
49
+ self.denoise_strength = kwargs.get('denoise_strength', 0.5)
50
+ self.trigger_word = kwargs.get('trigger_word', None)
51
+
52
+
53
+ class Img2ImgGenerator(BaseExtensionProcess):
54
+
55
+ def __init__(self, process_id: int, job, config: OrderedDict):
56
+ super().__init__(process_id, job, config)
57
+ self.output_folder = self.get_conf('output_folder', required=True)
58
+ self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
59
+ self.device = self.get_conf('device', 'cuda')
60
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
61
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
62
+ self.is_latents_cached = True
63
+ raw_datasets = self.get_conf('datasets', None)
64
+ if raw_datasets is not None and len(raw_datasets) > 0:
65
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
66
+ self.datasets = None
67
+ self.datasets_reg = None
68
+ self.dtype = self.get_conf('dtype', 'float16')
69
+ self.torch_dtype = get_torch_dtype(self.dtype)
70
+ self.params = []
71
+ if raw_datasets is not None and len(raw_datasets) > 0:
72
+ for raw_dataset in raw_datasets:
73
+ dataset = DatasetConfig(**raw_dataset)
74
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
75
+ if not is_caching:
76
+ self.is_latents_cached = False
77
+ if dataset.is_reg:
78
+ if self.datasets_reg is None:
79
+ self.datasets_reg = []
80
+ self.datasets_reg.append(dataset)
81
+ else:
82
+ if self.datasets is None:
83
+ self.datasets = []
84
+ self.datasets.append(dataset)
85
+
86
+ self.progress_bar = None
87
+ self.sd = StableDiffusion(
88
+ device=self.device,
89
+ model_config=self.model_config,
90
+ dtype=self.dtype,
91
+ )
92
+ print(f"Using device {self.device}")
93
+ self.data_loader: DataLoader = None
94
+ self.adapter: T2IAdapter = None
95
+
96
+ def to_pil(self, img):
97
+ # image comes in -1 to 1. convert to a PIL RGB image
98
+ img = (img + 1) / 2
99
+ img = img.clamp(0, 1)
100
+ img = img[0].permute(1, 2, 0).cpu().numpy()
101
+ img = (img * 255).astype(np.uint8)
102
+ image = Image.fromarray(img)
103
+ return image
104
+
105
+ def run(self):
106
+ with torch.no_grad():
107
+ super().run()
108
+ print("Loading model...")
109
+ self.sd.load_model()
110
+ device = torch.device(self.device)
111
+
112
+ if self.model_config.is_xl:
113
+ pipe = StableDiffusionXLImg2ImgPipeline(
114
+ vae=self.sd.vae,
115
+ unet=self.sd.unet,
116
+ text_encoder=self.sd.text_encoder[0],
117
+ text_encoder_2=self.sd.text_encoder[1],
118
+ tokenizer=self.sd.tokenizer[0],
119
+ tokenizer_2=self.sd.tokenizer[1],
120
+ scheduler=get_sampler(self.generate_config.sampler),
121
+ ).to(device, dtype=self.torch_dtype)
122
+ elif self.model_config.is_pixart:
123
+ pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
124
+ else:
125
+ raise NotImplementedError("Only XL models are supported")
126
+ pipe.set_progress_bar_config(disable=True)
127
+
128
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
129
+ # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
130
+
131
+ self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
132
+
133
+ num_batches = len(self.data_loader)
134
+ pbar = tqdm(total=num_batches, desc="Generating images")
135
+ seed = self.generate_config.seed
136
+ # load images from datasets, use tqdm
137
+ for i, batch in enumerate(self.data_loader):
138
+ batch: DataLoaderBatchDTO = batch
139
+
140
+ gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
141
+ generator = torch.manual_seed(gen_seed)
142
+
143
+ file_item: FileItemDTO = batch.file_items[0]
144
+ img_path = file_item.path
145
+ img_filename = os.path.basename(img_path)
146
+ img_filename_no_ext = os.path.splitext(img_filename)[0]
147
+ img_filename = img_filename_no_ext + '.' + self.generate_config.ext
148
+ output_path = os.path.join(self.output_folder, img_filename)
149
+ output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
150
+
151
+ if self.copy_inputs_to is not None:
152
+ output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
153
+ output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
154
+ else:
155
+ output_inputs_path = None
156
+ output_inputs_caption_path = None
157
+
158
+ caption = batch.get_caption_list()[0]
159
+ if self.generate_config.trigger_word is not None:
160
+ caption = caption.replace('[trigger]', self.generate_config.trigger_word)
161
+
162
+ img: torch.Tensor = batch.tensor.clone()
163
+ image = self.to_pil(img)
164
+
165
+ # image.save(output_depth_path)
166
+ if self.model_config.is_pixart:
167
+ pipe: PixArtSigmaPipeline = pipe
168
+
169
+ # Encode the full image once
170
+ encoded_image = pipe.vae.encode(
171
+ pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
172
+ if hasattr(encoded_image, "latent_dist"):
173
+ latents = encoded_image.latent_dist.sample(generator)
174
+ elif hasattr(encoded_image, "latents"):
175
+ latents = encoded_image.latents
176
+ else:
177
+ raise AttributeError("Could not access latents of provided encoder_output")
178
+ latents = pipe.vae.config.scaling_factor * latents
179
+
180
+ # latents = self.sd.encode_images(img)
181
+
182
+ # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
183
+ # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
184
+ # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
185
+ # timestep = timestep.to(device, dtype=torch.int32)
186
+ # latent = latent.to(device, dtype=self.torch_dtype)
187
+ # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
188
+ # latent = self.sd.add_noise(latent, noise, timestep)
189
+ # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
190
+ batch_size = 1
191
+ num_images_per_prompt = 1
192
+
193
+ shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
194
+ image.width // pipe.vae_scale_factor)
195
+ noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
196
+
197
+ # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
198
+ num_inference_steps = self.generate_config.sample_steps
199
+ strength = self.generate_config.denoise_strength
200
+ # Get timesteps
201
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
202
+ t_start = max(num_inference_steps - init_timestep, 0)
203
+ pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
204
+ timesteps = pipe.scheduler.timesteps[t_start:]
205
+ timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
206
+ latents = pipe.scheduler.add_noise(latents, noise, timestep)
207
+
208
+ gen_images = pipe.__call__(
209
+ prompt=caption,
210
+ negative_prompt=self.generate_config.neg,
211
+ latents=latents,
212
+ timesteps=timesteps,
213
+ width=image.width,
214
+ height=image.height,
215
+ num_inference_steps=num_inference_steps,
216
+ num_images_per_prompt=num_images_per_prompt,
217
+ guidance_scale=self.generate_config.guidance_scale,
218
+ # strength=self.generate_config.denoise_strength,
219
+ use_resolution_binning=False,
220
+ output_type="np"
221
+ ).images[0]
222
+ gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
223
+ gen_images = Image.fromarray(gen_images)
224
+ else:
225
+ pipe: StableDiffusionXLImg2ImgPipeline = pipe
226
+
227
+ gen_images = pipe.__call__(
228
+ prompt=caption,
229
+ negative_prompt=self.generate_config.neg,
230
+ image=image,
231
+ num_inference_steps=self.generate_config.sample_steps,
232
+ guidance_scale=self.generate_config.guidance_scale,
233
+ strength=self.generate_config.denoise_strength,
234
+ ).images[0]
235
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
236
+ gen_images.save(output_path)
237
+
238
+ # save caption
239
+ with open(output_caption_path, 'w') as f:
240
+ f.write(caption)
241
+
242
+ if output_inputs_path is not None:
243
+ os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
244
+ image.save(output_inputs_path)
245
+ with open(output_inputs_caption_path, 'w') as f:
246
+ f.write(caption)
247
+
248
+ pbar.update(1)
249
+ batch.cleanup()
250
+
251
+ pbar.close()
252
+ print("Done generating images")
253
+ # cleanup
254
+ del self.sd
255
+ gc.collect()
256
+ torch.cuda.empty_cache()
ai-toolkit/extensions_built_in/advanced_generator/PureLoraGenerator.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+
4
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
5
+ from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
6
+ from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
7
+ from toolkit.stable_diffusion_model import StableDiffusion
8
+ import gc
9
+ import torch
10
+ from jobs.process import BaseExtensionProcess
11
+ from toolkit.train_tools import get_torch_dtype
12
+
13
+
14
+ def flush():
15
+ torch.cuda.empty_cache()
16
+ gc.collect()
17
+
18
+
19
+ class PureLoraGenerator(BaseExtensionProcess):
20
+
21
+ def __init__(self, process_id: int, job, config: OrderedDict):
22
+ super().__init__(process_id, job, config)
23
+ self.output_folder = self.get_conf('output_folder', required=True)
24
+ self.device = self.get_conf('device', 'cuda')
25
+ self.device_torch = torch.device(self.device)
26
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
27
+ self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
28
+ self.dtype = self.get_conf('dtype', 'float16')
29
+ self.torch_dtype = get_torch_dtype(self.dtype)
30
+ lorm_config = self.get_conf('lorm', None)
31
+ self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
32
+
33
+ self.device_state_preset = get_train_sd_device_state_preset(
34
+ device=torch.device(self.device),
35
+ )
36
+
37
+ self.progress_bar = None
38
+ self.sd = StableDiffusion(
39
+ device=self.device,
40
+ model_config=self.model_config,
41
+ dtype=self.dtype,
42
+ )
43
+
44
+ def run(self):
45
+ super().run()
46
+ print("Loading model...")
47
+ with torch.no_grad():
48
+ self.sd.load_model()
49
+ self.sd.unet.eval()
50
+ self.sd.unet.to(self.device_torch)
51
+ if isinstance(self.sd.text_encoder, list):
52
+ for te in self.sd.text_encoder:
53
+ te.eval()
54
+ te.to(self.device_torch)
55
+ else:
56
+ self.sd.text_encoder.eval()
57
+ self.sd.to(self.device_torch)
58
+
59
+ print(f"Converting to LoRM UNet")
60
+ # replace the unet with LoRMUnet
61
+ convert_diffusers_unet_to_lorm(
62
+ self.sd.unet,
63
+ config=self.lorm_config,
64
+ )
65
+
66
+ sample_folder = os.path.join(self.output_folder)
67
+ gen_img_config_list = []
68
+
69
+ sample_config = self.generate_config
70
+ start_seed = sample_config.seed
71
+ current_seed = start_seed
72
+ for i in range(len(sample_config.prompts)):
73
+ if sample_config.walk_seed:
74
+ current_seed = start_seed + i
75
+
76
+ filename = f"[time]_[count].{self.generate_config.ext}"
77
+ output_path = os.path.join(sample_folder, filename)
78
+ prompt = sample_config.prompts[i]
79
+ extra_args = {}
80
+ gen_img_config_list.append(GenerateImageConfig(
81
+ prompt=prompt, # it will autoparse the prompt
82
+ width=sample_config.width,
83
+ height=sample_config.height,
84
+ negative_prompt=sample_config.neg,
85
+ seed=current_seed,
86
+ guidance_scale=sample_config.guidance_scale,
87
+ guidance_rescale=sample_config.guidance_rescale,
88
+ num_inference_steps=sample_config.sample_steps,
89
+ network_multiplier=sample_config.network_multiplier,
90
+ output_path=output_path,
91
+ output_ext=sample_config.ext,
92
+ adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
93
+ **extra_args
94
+ ))
95
+
96
+ # send to be generated
97
+ self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
98
+ print("Done generating images")
99
+ # cleanup
100
+ del self.sd
101
+ gc.collect()
102
+ torch.cuda.empty_cache()
ai-toolkit/extensions_built_in/advanced_generator/ReferenceGenerator.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from collections import OrderedDict
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+ from diffusers import T2IAdapter
9
+ from torch.utils.data import DataLoader
10
+ from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
11
+ from tqdm import tqdm
12
+
13
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
14
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
15
+ from toolkit.sampler import get_sampler
16
+ from toolkit.stable_diffusion_model import StableDiffusion
17
+ import gc
18
+ import torch
19
+ from jobs.process import BaseExtensionProcess
20
+ from toolkit.data_loader import get_dataloader_from_datasets
21
+ from toolkit.train_tools import get_torch_dtype
22
+ from controlnet_aux.midas import MidasDetector
23
+ from diffusers.utils import load_image
24
+
25
+
26
+ def flush():
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+
31
+ class GenerateConfig:
32
+
33
+ def __init__(self, **kwargs):
34
+ self.prompts: List[str]
35
+ self.sampler = kwargs.get('sampler', 'ddpm')
36
+ self.neg = kwargs.get('neg', '')
37
+ self.seed = kwargs.get('seed', -1)
38
+ self.walk_seed = kwargs.get('walk_seed', False)
39
+ self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
40
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
41
+ self.sample_steps = kwargs.get('sample_steps', 20)
42
+ self.prompt_2 = kwargs.get('prompt_2', None)
43
+ self.neg_2 = kwargs.get('neg_2', None)
44
+ self.prompts = kwargs.get('prompts', None)
45
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
46
+ self.ext = kwargs.get('ext', 'png')
47
+ self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
48
+ if kwargs.get('shuffle', False):
49
+ # shuffle the prompts
50
+ random.shuffle(self.prompts)
51
+
52
+
53
+ class ReferenceGenerator(BaseExtensionProcess):
54
+
55
+ def __init__(self, process_id: int, job, config: OrderedDict):
56
+ super().__init__(process_id, job, config)
57
+ self.output_folder = self.get_conf('output_folder', required=True)
58
+ self.device = self.get_conf('device', 'cuda')
59
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
60
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
61
+ self.is_latents_cached = True
62
+ raw_datasets = self.get_conf('datasets', None)
63
+ if raw_datasets is not None and len(raw_datasets) > 0:
64
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
65
+ self.datasets = None
66
+ self.datasets_reg = None
67
+ self.dtype = self.get_conf('dtype', 'float16')
68
+ self.torch_dtype = get_torch_dtype(self.dtype)
69
+ self.params = []
70
+ if raw_datasets is not None and len(raw_datasets) > 0:
71
+ for raw_dataset in raw_datasets:
72
+ dataset = DatasetConfig(**raw_dataset)
73
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
74
+ if not is_caching:
75
+ self.is_latents_cached = False
76
+ if dataset.is_reg:
77
+ if self.datasets_reg is None:
78
+ self.datasets_reg = []
79
+ self.datasets_reg.append(dataset)
80
+ else:
81
+ if self.datasets is None:
82
+ self.datasets = []
83
+ self.datasets.append(dataset)
84
+
85
+ self.progress_bar = None
86
+ self.sd = StableDiffusion(
87
+ device=self.device,
88
+ model_config=self.model_config,
89
+ dtype=self.dtype,
90
+ )
91
+ print(f"Using device {self.device}")
92
+ self.data_loader: DataLoader = None
93
+ self.adapter: T2IAdapter = None
94
+
95
+ def run(self):
96
+ super().run()
97
+ print("Loading model...")
98
+ self.sd.load_model()
99
+ device = torch.device(self.device)
100
+
101
+ if self.generate_config.t2i_adapter_path is not None:
102
+ self.adapter = T2IAdapter.from_pretrained(
103
+ self.generate_config.t2i_adapter_path,
104
+ torch_dtype=self.torch_dtype,
105
+ varient="fp16"
106
+ ).to(device)
107
+
108
+ midas_depth = MidasDetector.from_pretrained(
109
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
110
+ ).to(device)
111
+
112
+ if self.model_config.is_xl:
113
+ pipe = StableDiffusionXLAdapterPipeline(
114
+ vae=self.sd.vae,
115
+ unet=self.sd.unet,
116
+ text_encoder=self.sd.text_encoder[0],
117
+ text_encoder_2=self.sd.text_encoder[1],
118
+ tokenizer=self.sd.tokenizer[0],
119
+ tokenizer_2=self.sd.tokenizer[1],
120
+ scheduler=get_sampler(self.generate_config.sampler),
121
+ adapter=self.adapter,
122
+ ).to(device, dtype=self.torch_dtype)
123
+ else:
124
+ pipe = StableDiffusionAdapterPipeline(
125
+ vae=self.sd.vae,
126
+ unet=self.sd.unet,
127
+ text_encoder=self.sd.text_encoder,
128
+ tokenizer=self.sd.tokenizer,
129
+ scheduler=get_sampler(self.generate_config.sampler),
130
+ safety_checker=None,
131
+ feature_extractor=None,
132
+ requires_safety_checker=False,
133
+ adapter=self.adapter,
134
+ ).to(device, dtype=self.torch_dtype)
135
+ pipe.set_progress_bar_config(disable=True)
136
+
137
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
138
+ # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
139
+
140
+ self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
141
+
142
+ num_batches = len(self.data_loader)
143
+ pbar = tqdm(total=num_batches, desc="Generating images")
144
+ seed = self.generate_config.seed
145
+ # load images from datasets, use tqdm
146
+ for i, batch in enumerate(self.data_loader):
147
+ batch: DataLoaderBatchDTO = batch
148
+
149
+ file_item: FileItemDTO = batch.file_items[0]
150
+ img_path = file_item.path
151
+ img_filename = os.path.basename(img_path)
152
+ img_filename_no_ext = os.path.splitext(img_filename)[0]
153
+ output_path = os.path.join(self.output_folder, img_filename)
154
+ output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
155
+ output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
156
+
157
+ caption = batch.get_caption_list()[0]
158
+
159
+ img: torch.Tensor = batch.tensor.clone()
160
+ # image comes in -1 to 1. convert to a PIL RGB image
161
+ img = (img + 1) / 2
162
+ img = img.clamp(0, 1)
163
+ img = img[0].permute(1, 2, 0).cpu().numpy()
164
+ img = (img * 255).astype(np.uint8)
165
+ image = Image.fromarray(img)
166
+
167
+ width, height = image.size
168
+ min_res = min(width, height)
169
+
170
+ if self.generate_config.walk_seed:
171
+ seed = seed + 1
172
+
173
+ if self.generate_config.seed == -1:
174
+ # random
175
+ seed = random.randint(0, 1000000)
176
+
177
+ torch.manual_seed(seed)
178
+ torch.cuda.manual_seed(seed)
179
+
180
+ # generate depth map
181
+ image = midas_depth(
182
+ image,
183
+ detect_resolution=min_res, # do 512 ?
184
+ image_resolution=min_res
185
+ )
186
+
187
+ # image.save(output_depth_path)
188
+
189
+ gen_images = pipe(
190
+ prompt=caption,
191
+ negative_prompt=self.generate_config.neg,
192
+ image=image,
193
+ num_inference_steps=self.generate_config.sample_steps,
194
+ adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
195
+ guidance_scale=self.generate_config.guidance_scale,
196
+ ).images[0]
197
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
198
+ gen_images.save(output_path)
199
+
200
+ # save caption
201
+ with open(output_caption_path, 'w') as f:
202
+ f.write(caption)
203
+
204
+ pbar.update(1)
205
+ batch.cleanup()
206
+
207
+ pbar.close()
208
+ print("Done generating images")
209
+ # cleanup
210
+ del self.sd
211
+ gc.collect()
212
+ torch.cuda.empty_cache()
ai-toolkit/extensions_built_in/advanced_generator/__init__.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
6
+ class AdvancedReferenceGeneratorExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "reference_generator"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Reference Generator"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ReferenceGenerator import ReferenceGenerator
19
+ return ReferenceGenerator
20
+
21
+
22
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
23
+ class PureLoraGenerator(Extension):
24
+ # uid must be unique, it is how the extension is identified
25
+ uid = "pure_lora_generator"
26
+
27
+ # name is the name of the extension for printing
28
+ name = "Pure LoRA Generator"
29
+
30
+ # This is where your process class is loaded
31
+ # keep your imports in here so they don't slow down the rest of the program
32
+ @classmethod
33
+ def get_process(cls):
34
+ # import your process class here so it is only loaded when needed and return it
35
+ from .PureLoraGenerator import PureLoraGenerator
36
+ return PureLoraGenerator
37
+
38
+
39
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
40
+ class Img2ImgGeneratorExtension(Extension):
41
+ # uid must be unique, it is how the extension is identified
42
+ uid = "batch_img2img"
43
+
44
+ # name is the name of the extension for printing
45
+ name = "Img2ImgGeneratorExtension"
46
+
47
+ # This is where your process class is loaded
48
+ # keep your imports in here so they don't slow down the rest of the program
49
+ @classmethod
50
+ def get_process(cls):
51
+ # import your process class here so it is only loaded when needed and return it
52
+ from .Img2ImgGenerator import Img2ImgGenerator
53
+ return Img2ImgGenerator
54
+
55
+
56
+ AI_TOOLKIT_EXTENSIONS = [
57
+ # you can put a list of extensions here
58
+ AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
59
+ ]
ai-toolkit/extensions_built_in/advanced_generator/config/train.example.yaml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: test_v1
5
+ process:
6
+ - type: 'textual_inversion_trainer'
7
+ training_folder: "out/TI"
8
+ device: cuda:0
9
+ # for tensorboard logging
10
+ log_dir: "out/.tensorboard"
11
+ embedding:
12
+ trigger: "your_trigger_here"
13
+ tokens: 12
14
+ init_words: "man with short brown hair"
15
+ save_format: "safetensors" # 'safetensors' or 'pt'
16
+ save:
17
+ dtype: float16 # precision to save
18
+ save_every: 100 # save every this many steps
19
+ max_step_saves_to_keep: 5 # only affects step counts
20
+ datasets:
21
+ - folder_path: "/path/to/dataset"
22
+ caption_ext: "txt"
23
+ default_caption: "[trigger]"
24
+ buckets: true
25
+ resolution: 512
26
+ train:
27
+ noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
28
+ steps: 3000
29
+ weight_jitter: 0.0
30
+ lr: 5e-5
31
+ train_unet: false
32
+ gradient_checkpointing: true
33
+ train_text_encoder: false
34
+ optimizer: "adamw"
35
+ # optimizer: "prodigy"
36
+ optimizer_params:
37
+ weight_decay: 1e-2
38
+ lr_scheduler: "constant"
39
+ max_denoising_steps: 1000
40
+ batch_size: 4
41
+ dtype: bf16
42
+ xformers: true
43
+ min_snr_gamma: 5.0
44
+ # skip_first_sample: true
45
+ noise_offset: 0.0 # not needed for this
46
+ model:
47
+ # objective reality v2
48
+ name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
49
+ is_v2: false # for v2 models
50
+ is_xl: false # for SDXL models
51
+ is_v_pred: false # for v-prediction models (most v2 models)
52
+ sample:
53
+ sampler: "ddpm" # must match train.noise_scheduler
54
+ sample_every: 100 # sample every this many steps
55
+ width: 512
56
+ height: 512
57
+ prompts:
58
+ - "photo of [trigger] laughing"
59
+ - "photo of [trigger] smiling"
60
+ - "[trigger] close up"
61
+ - "dark scene [trigger] frozen"
62
+ - "[trigger] nighttime"
63
+ - "a painting of [trigger]"
64
+ - "a drawing of [trigger]"
65
+ - "a cartoon of [trigger]"
66
+ - "[trigger] pixar style"
67
+ - "[trigger] costume"
68
+ neg: ""
69
+ seed: 42
70
+ walk_seed: false
71
+ guidance_scale: 7
72
+ sample_steps: 20
73
+ network_multiplier: 1.0
74
+
75
+ logging:
76
+ log_every: 10 # log every this many steps
77
+ use_wandb: false # not supported yet
78
+ verbose: false
79
+
80
+ # You can put any information you want here, and it will be saved in the model.
81
+ # The below is an example, but you can put your grocery list in it if you want.
82
+ # It is saved in the model so be aware of that. The software will include this
83
+ # plus some other information for you automatically
84
+ meta:
85
+ # [name] gets replaced with the name above
86
+ name: "[name]"
87
+ # version: '1.0'
88
+ # creator:
89
+ # name: Your Name
90
+ # email: your@gmail.com
91
+ # website: https://your.website
ai-toolkit/extensions_built_in/audio_models/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .ace_step import AceStep15Model, AceStep15XLModel
2
+
3
+ AI_TOOLKIT_MODELS = [
4
+ # put a list of models here
5
+ AceStep15Model,
6
+ AceStep15XLModel,
7
+ ]
ai-toolkit/extensions_built_in/audio_models/ace_step/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .ace_step_15_model import AceStep15Model, AceStep15XLModel
ai-toolkit/extensions_built_in/audio_models/ace_step/ace_step_15_model.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import List, Optional
4
+ import huggingface_hub
5
+ import torch
6
+ from safetensors.torch import load_file, save_file
7
+ from extensions_built_in.audio_models.base_audio_model import BaseAudioModel
8
+ from toolkit.basic import flush
9
+ from toolkit.config_modules import GenerateImageConfig
10
+ from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
11
+ from toolkit.samplers.custom_flowmatch_sampler import (
12
+ CustomFlowMatchEulerDiscreteScheduler,
13
+ )
14
+ from toolkit.util.quantize import get_qtype, quantize, quantize_model
15
+
16
+ from optimum.quanto import freeze
17
+ from .src.model import (
18
+ AceStep15,
19
+ OobleckVAE,
20
+ TextEncoder,
21
+ get_silence_latent,
22
+ load_models,
23
+ )
24
+ from transformers import AutoTokenizer
25
+ from .src.pipeline import AceStep15Pipeline
26
+
27
+ scheduler_config = {
28
+ "num_train_timesteps": 1000,
29
+ "shift": 3.0,
30
+ "use_dynamic_shifting": False,
31
+ }
32
+
33
+ def to_number(str_or_number, default):
34
+ if isinstance(str_or_number, (int, float)):
35
+ return str_or_number
36
+ if str_or_number is None:
37
+ return default
38
+ if str_or_number == "":
39
+ return default
40
+ try:
41
+ return float(str_or_number)
42
+ except ValueError:
43
+ try:
44
+ return int(str_or_number)
45
+ except ValueError as e:
46
+ raise ValueError(f"Could not convert {str_or_number} to a number") from e
47
+
48
+
49
+ def parse_ace_step_caption(text):
50
+ """Parse a tagged caption file back into a dict."""
51
+ import re
52
+
53
+ def tag(name):
54
+ m = re.search(rf"<{name}>(.*?)</{name}>", text, re.DOTALL)
55
+ return m.group(1).strip() if m else ""
56
+
57
+ return {
58
+ "caption": tag("CAPTION"),
59
+ "lyrics": tag("LYRICS"),
60
+ "bpm": to_number(tag("BPM"), 120),
61
+ "keyscale": tag("KEYSCALE"),
62
+ "timesignature": tag("TIMESIGNATURE"),
63
+ "duration": to_number(tag("DURATION"), 1.0),
64
+ "language": tag("LANGUAGE"),
65
+ }
66
+
67
+
68
+ class AceStep15Model(BaseAudioModel):
69
+ arch = "ace_step_15"
70
+ sample_rate = 48000
71
+
72
+ def __init__(
73
+ self,
74
+ device,
75
+ model_config,
76
+ dtype="bf16",
77
+ custom_pipeline=None,
78
+ noise_scheduler=None,
79
+ **kwargs,
80
+ ):
81
+ super().__init__(
82
+ device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
83
+ )
84
+ self.is_flow_matching = True
85
+ self.is_transformer = True
86
+ # self.target_lora_modules = ['AceStep15']
87
+ self.target_lora_modules = ["DiTModel"]
88
+
89
+ # static method to get the noise scheduler
90
+ @staticmethod
91
+ def get_train_scheduler():
92
+ return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
93
+
94
+ def load_model(self):
95
+ dtype = self.torch_dtype
96
+ device = self.device_torch
97
+
98
+ model_path = self.model_config.name_or_path
99
+
100
+ if not os.path.exists(model_path):
101
+ # assume it is a hf repo like org/repo/filename.safetensors
102
+ path_parts = model_path.split("/")
103
+ if len(path_parts) != 3:
104
+ raise ValueError(
105
+ f"Model path {model_path} does not exist and is not a valid Hugging Face repo path"
106
+ )
107
+ model_path = huggingface_hub.hf_hub_download(
108
+ repo_id=f"{path_parts[0]}/{path_parts[1]}",
109
+ filename=path_parts[2],
110
+ )
111
+ # load the models from the single safetensors file
112
+ load_device = device
113
+ if self.model_config.low_vram:
114
+ load_device = "cpu"
115
+
116
+ models = load_models(model_path, device=load_device, dtype=dtype)
117
+
118
+ self.model = models["model"]
119
+
120
+ if self.model_config.quantize:
121
+ self.print_and_status_update("Quantizing Transformer")
122
+ # quantize_model(self, self.model.decoder)
123
+ quantize(self.model, weights=get_qtype(self.model_config.qtype))
124
+ freeze(self.model)
125
+ flush()
126
+
127
+ if self.model_config.low_vram:
128
+ self.print_and_status_update("Moving transformer to CPU")
129
+ self.model.to("cpu")
130
+
131
+
132
+ if (
133
+ self.model_config.layer_offloading
134
+ and self.model_config.layer_offloading_transformer_percent > 0
135
+ ):
136
+ raise NotImplementedError("Layer offloading not yet implemented for AceStep15Model")
137
+
138
+ self.text_encoder = models["text_encoder"]
139
+
140
+ if self.model_config.quantize_te:
141
+ self.print_and_status_update("Quantizing Text Encoder")
142
+ quantize(self.text_encoder, weights=get_qtype(self.model_config.qtype_te))
143
+ freeze(self.text_encoder)
144
+ flush()
145
+
146
+ self.vae = models["vae"]
147
+
148
+ # move back to device
149
+ self.model.to(device)
150
+ self.text_encoder.to(device)
151
+ self.vae.to(device)
152
+ self.tokenizer = models["tokenizer"]
153
+
154
+ self.pipeline = AceStep15Pipeline(
155
+ transformer=self.model,
156
+ vae=self.vae,
157
+ text_encoder=self.text_encoder,
158
+ tokenizer=self.tokenizer,
159
+ scheduler=self.get_train_scheduler(),
160
+ )
161
+ if self.model_config.low_vram:
162
+ self.pipeline.do_tiled_decoding = True
163
+
164
+ def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
165
+ if isinstance(prompt, str):
166
+ prompts = [prompt]
167
+ else:
168
+ prompts = prompt
169
+
170
+ if self.text_encoder.device == torch.device("cpu"):
171
+ self.text_encoder.to(self.device_torch)
172
+ # we need the encoder from the model
173
+ if self.model.encoder.device == torch.device("cpu"):
174
+ self.model.encoder.to(self.device_torch)
175
+
176
+ # the prompt should be json as a string. Try to parse it.
177
+ json_prompts = []
178
+ for p in prompts:
179
+ try:
180
+ json_prompts.append(parse_ace_step_caption(p))
181
+ except json.JSONDecodeError:
182
+ raise ValueError(
183
+ f"Prompt {p} is not a valid JSON string. Prompts must be JSON for this model"
184
+ )
185
+
186
+ if self.pipeline.text_encoder.device == torch.device("cpu"):
187
+ self.pipeline.text_encoder.to(self.device_torch)
188
+
189
+ device = self.text_encoder.device
190
+ dtype = self.text_encoder.dtype
191
+
192
+ batch_pe = None
193
+ # TODO not sure this will allow for proper batching
194
+
195
+ for json_prompt in json_prompts:
196
+ prompt = json_prompt.get("caption", "")
197
+ lyrics = json_prompt.get("lyrics", "")
198
+ bpm = json_prompt.get("bpm", 120)
199
+ key = json_prompt.get("key", "C")
200
+ time_sig = json_prompt.get("time_sig", "4/4")
201
+ duration = json_prompt.get("duration", 10)
202
+ duration = int(duration) if isinstance(duration, (int, float)) else 10
203
+ language = json_prompt.get("language", "en")
204
+
205
+ text_embeddings, text_mask, lyric_embeddings, lyric_mask = (
206
+ self.pipeline.get_text_embedings(
207
+ prompt, lyrics, bpm, key, time_sig, duration, language
208
+ )
209
+ )
210
+ latent_len = int(duration * self.pipeline.LATENT_RATE)
211
+ # Silence as source latent [1, 64, T] -> [1, T, 64] for DiT
212
+ sil = get_silence_latent(latent_len, device, dtype) # [1, 64, T]
213
+ src = sil.transpose(1, 2) # [1, T, 64]
214
+ chunk_masks = torch.ones_like(src)
215
+
216
+ # Reference audio (silence)
217
+ ref = sil[:, :, :750].transpose(1, 2) # [1, 750, 64]
218
+ ref_order = torch.zeros(1, device=device, dtype=torch.long)
219
+ enc_h, enc_m, _ = self.pipeline.transformer.prepare_condition(
220
+ text_embeddings,
221
+ text_mask,
222
+ lyric_embeddings,
223
+ lyric_mask,
224
+ ref,
225
+ ref_order,
226
+ src,
227
+ chunk_masks,
228
+ )
229
+
230
+ pe = PromptEmbeds(enc_h, attention_mask=enc_m)
231
+ if batch_pe is None:
232
+ batch_pe = pe
233
+ else:
234
+ batch_pe = concat_prompt_embeds(batch_pe, pe)
235
+ return batch_pe
236
+
237
+ def get_transformer_block_names(self) -> Optional[List[str]]:
238
+ return ["layers"]
239
+
240
+ def get_generation_pipeline(self):
241
+ return self.pipeline
242
+
243
+ def generate_single_audio(
244
+ self,
245
+ pipeline,
246
+ gen_config: GenerateImageConfig,
247
+ conditional_embeds: PromptEmbeds,
248
+ unconditional_embeds: PromptEmbeds,
249
+ generator: torch.Generator,
250
+ extra: dict,
251
+ ):
252
+ if self.model.device == torch.device("cpu"):
253
+ self.model.to(self.device_torch)
254
+ # make sure gen config is setup for audio
255
+ if gen_config.output_ext not in ['mp3', 'wav']:
256
+ gen_config.output_ext = 'mp3'
257
+ prompt = gen_config.prompt
258
+ json_prompt = parse_ace_step_caption(prompt)
259
+ prompt = json_prompt.get("caption", "")
260
+ lyrics = json_prompt.get("lyrics", "")
261
+ bpm = json_prompt.get("bpm", 120)
262
+ key = json_prompt.get("key", "C")
263
+ time_sig = json_prompt.get("time_sig", "4/4")
264
+ duration = json_prompt.get("duration", 0)
265
+ language = json_prompt.get("language", "en")
266
+
267
+ output = self.pipeline(
268
+ prompt=None, # we are passing in the embeds directly, so no need for a prompt
269
+ encoder_embeddings=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.torch_dtype),
270
+ encoder_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.bool),
271
+ num_inference_steps=gen_config.num_inference_steps,
272
+ duration=duration,
273
+ generator=generator,
274
+ bpm=bpm,
275
+ key=key,
276
+ time_sig=time_sig,
277
+ language=language,
278
+ guidance_scale=gen_config.guidance_scale,
279
+ )
280
+ return output
281
+
282
+ def get_noise_prediction(
283
+ self,
284
+ latent_model_input: torch.Tensor, #(1, 300, 64)
285
+ timestep: torch.Tensor, # 0 to 1000 scale
286
+ text_embeddings: PromptEmbeds,
287
+ **kwargs,
288
+ ):
289
+ if self.model.decoder.device == torch.device("cpu"):
290
+ self.model.decoder.to(self.device_torch)
291
+ with torch.no_grad():
292
+ model: AceStep15 = self.model
293
+ tt = timestep.to(self.device_torch, dtype=torch.long) / 1000
294
+ latent_len = latent_model_input.shape[1]
295
+ device = self.device_torch
296
+ dtype = self.torch_dtype
297
+ attn = torch.ones(1, latent_len, device=device, dtype=dtype)
298
+
299
+ # build context from silence latent matching the actual input length
300
+ sil = get_silence_latent(latent_len, device, dtype) # [1, 64, T]
301
+ src = sil.transpose(1, 2) # [1, T, 64]
302
+ chunk_masks = torch.ones_like(src)
303
+ context = torch.cat([src, chunk_masks], dim=-1) # [1, T, 128]
304
+
305
+ pred = model.decoder(
306
+ x=latent_model_input.detach(),
307
+ timestep=tt.detach(),
308
+ timestep_r=tt.detach(),
309
+ attention_mask=attn.detach(),
310
+ enc_h=text_embeddings.text_embeds.to(self.device_torch, dtype=self.torch_dtype).detach(),
311
+ enc_m=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.bool).detach(),
312
+ context=context.detach(),
313
+ )
314
+ return pred
315
+
316
+ def get_loss_target(self, *args, **kwargs):
317
+ noise = kwargs.get("noise")
318
+ batch = kwargs.get("batch")
319
+ return (noise - batch.latents).detach()
320
+
321
+ def encode_audio(self, audio_tensor: torch.Tensor, device=None, dtype=None):
322
+ if device is None:
323
+ device = self.device_torch
324
+ if dtype is None:
325
+ dtype = self.torch_dtype
326
+ if self.vae.device == torch.device("cpu"):
327
+ self.vae.to(device)
328
+ output = self.vae.encode(audio_tensor.to(device=device, dtype=dtype))
329
+ # transpose from [B, 64, T] to [B, T, 64] for DiT
330
+ output = output.transpose(1, 2).contiguous()
331
+ return output
332
+
333
+
334
+ class AceStep15XLModel(AceStep15Model):
335
+ arch = "ace_step_15_xl"
ai-toolkit/extensions_built_in/audio_models/ace_step/src/__init__.py ADDED
File without changes